@@ -133,10 +133,10 @@ def __call__(
         if attn.parallel_proj_in:
             hidden_states = attn.to_qkv_mlp_proj(hidden_states)
             qkv, mlp_hidden_states = torch.split(
-                hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor]
+                hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
             )
             query, key, value = qkv.chunk(3, dim=-1)
-            mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
+            mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

         # Get encoder QKV, if available
         encoder_query = encoder_key = encoder_value = None
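A quick sanity check on the `dim=-1` fix: `torch.split` defaults to `dim=0`, so without the explicit argument the split would be attempted along the batch axis instead of the feature axis of the fused projection output. A minimal sketch with hypothetical dimensions:

```python
import torch

# Hypothetical sizes; the real values come from the attention module's config.
batch, seq_len = 2, 16
inner_dim, mlp_hidden_dim, mlp_mult_factor = 64, 128, 2

# Fused QKV + MLP projection output: [batch, seq_len, 3*inner_dim + mlp_hidden_dim*mlp_mult_factor]
fused = torch.randn(batch, seq_len, 3 * inner_dim + mlp_hidden_dim * mlp_mult_factor)

# Split along the feature (last) axis, as in the patched code.
qkv, mlp_hidden = torch.split(fused, [3 * inner_dim, mlp_hidden_dim * mlp_mult_factor], dim=-1)
assert qkv.shape == (batch, seq_len, 3 * inner_dim)
assert mlp_hidden.shape == (batch, seq_len, mlp_hidden_dim * mlp_mult_factor)

query, key, value = qkv.chunk(3, dim=-1)
assert query.shape == (batch, seq_len, inner_dim)
```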
@@ -423,6 +423,7 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}

+        # Modulation parameters shape: [1, 1, self.dim]
         (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
         (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt

@@ -448,27 +449,27 @@ def forward(
         attn_output, context_attn_output, ip_attn_output = attention_outputs

         # Process attention outputs for the image stream (`hidden_states`).
-        attn_output = gate_msa.unsqueeze(1) * attn_output
+        attn_output = gate_msa * attn_output
         hidden_states = hidden_states + attn_output

         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+        hidden_states = hidden_states + gate_mlp * ff_output

         if len(attention_outputs) == 3:
             hidden_states = hidden_states + ip_attn_output

         # Process attention outputs for the text stream (`encoder_hidden_states`).
-        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+        context_attn_output = c_gate_msa * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output

         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp

         context_ff_output = self.ff_context(norm_encoder_hidden_states)
-        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

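Dropping the `.unsqueeze(1)` / `[:, None]` expansions relies on the modulation parameters already carrying a singleton sequence axis (the `[1, 1, self.dim]` shape noted above), so plain broadcasting against `[batch, seq_len, dim]` activations does the rest. A minimal sketch with hypothetical shapes:

```python
import torch

batch, seq_len, dim = 2, 16, 8
hidden_states = torch.randn(batch, seq_len, dim)

# Modulation parameters shaped [1, 1, dim], as assumed by the patched block.
gate_msa = torch.randn(1, 1, dim)
scale_mlp = torch.randn(1, 1, dim)
shift_mlp = torch.randn(1, 1, dim)

# Broadcasting covers the batch and sequence axes without explicit unsqueezes.
gated = gate_msa * hidden_states
modulated = hidden_states * (1 + scale_mlp) + shift_mlp
assert gated.shape == (batch, seq_len, dim)
assert modulated.shape == (batch, seq_len, dim)
```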
@@ -483,6 +484,7 @@ def __init__(self, theta: int, axes_dim: List[int]):
         self.axes_dim = axes_dim

     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        # Expected ids shape: [S, len(self.axes_dim)]
         cos_out = []
         sin_out = []
         pos = ids.float()
@@ -493,7 +495,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         for i in range(len(self.axes_dim)):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
-                pos[:, i],
+                pos[..., i],
                 theta=self.theta,
                 repeat_interleave_real=True,
                 use_real=True,
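The switch from `pos[:, i]` to `pos[..., i]` pairs with the expected `[S, len(self.axes_dim)]` shape of `ids`: both forms pick the same per-axis column on 2D input, but the ellipsis also indexes the last axis correctly if a leading batch dimension slips through. A small sketch with hypothetical values:

```python
import torch

seq_len, num_axes = 6, 3
ids = torch.arange(seq_len * num_axes).reshape(seq_len, num_axes)
pos = ids.float()

# On 2D ids the two indexing styles are identical.
assert torch.equal(pos[..., 0], pos[:, 0])

# With a leading batch axis, only the ellipsis form still selects the axis column.
batched = pos.unsqueeze(0)  # [1, seq_len, num_axes]
assert batched[..., 0].shape == (1, seq_len)
```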
@@ -736,6 +738,8 @@ def forward(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

+        num_txt_tokens = encoder_hidden_states.shape[1]
+
         # 1. Calculate timestep embedding and modulation parameters
         timestep = timestep.to(hidden_states.dtype) * 1000
         guidance = guidance.to(hidden_states.dtype) * 1000
@@ -751,6 +755,13 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         # 3. Calculate RoPE embeddings from image and text tokens
+        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
+        # text prompts of different lengths. Is this a use case we want to support?
+        if img_ids.ndim == 3:
+            img_ids = img_ids[0]
+        if txt_ids.ndim == 3:
+            txt_ids = txt_ids[0]
+
         if is_torch_npu_available():
             freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
             image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
@@ -760,8 +771,8 @@ def forward(
             image_rotary_emb = self.pos_embed(img_ids)
             text_rotary_emb = self.pos_embed(txt_ids)
         concat_rotary_emb = (
-            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=2),
-            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=2),
+            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
         )

         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
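The ids squeeze above plus the `dim=0` change assume the position embedder is fed unbatched `[S, num_axes]` ids and returns per-token cos/sin tables shaped `[S, rope_dim]`, so text and image frequencies are joined along the token axis rather than a (now nonexistent) third axis. A sketch under those assumptions, with hypothetical sizes:

```python
import torch

rope_dim = 16
num_txt, num_img = 77, 1024

# Reduce batched ids to 2D, mirroring the img_ids/txt_ids handling above.
txt_ids = torch.zeros(1, num_txt, 3)[0]   # [num_txt, 3]
img_ids = torch.zeros(1, num_img, 3)[0]   # [num_img, 3]

# Stand-ins for the per-token cos/sin tables the embedder would produce.
txt_cos, txt_sin = torch.randn(num_txt, rope_dim), torch.randn(num_txt, rope_dim)
img_cos, img_sin = torch.randn(num_img, rope_dim), torch.randn(num_img, rope_dim)

concat_rotary_emb = (
    torch.cat([txt_cos, img_cos], dim=0),
    torch.cat([txt_sin, img_sin], dim=0),
)
assert concat_rotary_emb[0].shape == (num_txt + num_img, rope_dim)
assert concat_rotary_emb[1].shape == (num_txt + num_img, rope_dim)
```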
@@ -790,26 +801,30 @@ def forward(
                     image_rotary_emb=concat_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
+        # Concatenate text and image streams for single-block inference
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         # 5. Single Stream Transformer Blocks
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
-                    encoder_hidden_states,
+                    None,
                     single_stream_mod,
                     concat_rotary_emb,
                     joint_attention_kwargs,
                 )
             else:
-                encoder_hidden_states, hidden_states = block(
+                hidden_states = block(
                     hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states=None,
                     temb_mod_params=single_stream_mod,
                     image_rotary_emb=concat_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
+        # Remove text tokens from concatenated stream
+        hidden_states = hidden_states[:, num_txt_tokens:, ...]

         # 6. Output layers
         hidden_states = self.norm_out(hidden_states, temb)
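The single-stream rewrite follows a concatenate-then-slice pattern: text tokens are prepended along the sequence axis before the loop, and `num_txt_tokens` (captured earlier from `encoder_hidden_states`) strips them off again afterwards. A minimal sketch with hypothetical sizes:

```python
import torch

batch, dim = 2, 8
encoder_hidden_states = torch.randn(batch, 77, dim)   # text stream
hidden_states = torch.randn(batch, 1024, dim)         # image stream
num_txt_tokens = encoder_hidden_states.shape[1]

# Text tokens first, image tokens after, joined on the sequence axis.
combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert combined.shape == (batch, 77 + 1024, dim)

# ... the single-stream blocks would operate on `combined` here ...

# Drop the text tokens again before the output layers.
image_only = combined[:, num_txt_tokens:, ...]
assert image_only.shape == hidden_states.shape
```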