@@ -189,12 +189,11 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
-        The [`FluxTransformer2DModel`] forward method.
+        The [`QwenImageControlNetModel`] forward method.
 
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -205,26 +204,24 @@ def forward(
                 The scale factor for ControlNet outputs.
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            txt_seq_lens (`List[int]`, *optional*):
-                Optional text sequence lengths. If omitted, or shorter than the encoder hidden states length, the model
-                derives the length from the encoder hidden states (or their mask).
+            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+                Image shapes for RoPE computation.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
 
         Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+            the first element is the controlnet block samples.
         """
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -247,13 +244,9 @@ def forward(
 
         temb = self.time_text_embed(timestep, hidden_states)
 
-        batch_size, text_seq_len = encoder_hidden_states.shape[:2]
-        if txt_seq_lens is not None:
-            if len(txt_seq_lens) != batch_size:
-                raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.")
-            text_seq_len = max(text_seq_len, max(txt_seq_lens))
-        elif encoder_hidden_states_mask is not None:
-            text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item()))
+        # Use the encoder_hidden_states sequence length for RoPE computation
+        # The mask is used for attention masking in the attention processor
+        _, text_seq_len = encoder_hidden_states.shape[:2]
 
         image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)
 
@@ -332,7 +325,6 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[QwenImageControlNetOutput, Tuple]:
@@ -350,7 +342,6 @@ def forward(
             encoder_hidden_states_mask=encoder_hidden_states_mask,
             timestep=timestep,
             img_shapes=img_shapes,
-            txt_seq_lens=txt_seq_lens,
             joint_attention_kwargs=joint_attention_kwargs,
             return_dict=return_dict,
         )
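
With `txt_seq_lens` gone, callers pass only the padded `encoder_hidden_states` plus its mask; per-sample prompt lengths are encoded in the mask alone. A minimal calling sketch, assuming an already-instantiated `QwenImageControlNetModel` named `controlnet` and made-up sizes (the 64-channel packed latents, 3584-dim text embeddings, and `img_shapes` values below are assumptions for illustration, not values taken from this commit):

```python
import torch

# Hypothetical sizes for illustration only.
batch_size, img_seq_len, text_seq_len = 2, 1024, 77
hidden_states = torch.randn(batch_size, img_seq_len, 64)
controlnet_cond = torch.randn(batch_size, img_seq_len, 64)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 3584)

# 1.0 for valid tokens, 0.0 for padding; the second prompt is shorter,
# which previously would have required passing per-sample txt_seq_lens.
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len)
encoder_hidden_states_mask[1, 50:] = 0.0

controlnet_block_samples = controlnet(  # assumed QwenImageControlNetModel instance
    hidden_states=hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=1.0,
    encoder_hidden_states=encoder_hidden_states,
    encoder_hidden_states_mask=encoder_hidden_states_mask,
    timestep=torch.tensor([1000] * batch_size),
    img_shapes=[(1, 32, 32)] * batch_size,
    return_dict=False,
)
```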
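The removed length-derivation logic is redundant because the mask already carries per-token validity and is applied element-wise on the keys inside attention. A sketch of how such a mask can be folded into joint attention; the token ordering and function below are assumptions for illustration, not the actual diffusers attention processor:

```python
import torch
import torch.nn.functional as F

def joint_attention_with_text_mask(query, key, value, encoder_hidden_states_mask, text_seq_len):
    # query/key/value: (batch, heads, img_seq_len + text_seq_len, head_dim).
    # Assumes text tokens sit at the end of the joint sequence; the real
    # processor's ordering may differ.
    batch_size, _, total_seq_len, _ = query.shape
    attn_mask = torch.ones(batch_size, total_seq_len, dtype=torch.bool, device=query.device)
    attn_mask[:, total_seq_len - text_seq_len :] = encoder_hidden_states_mask.bool()
    # Broadcast to (batch, 1, 1, total_seq_len): every query position ignores
    # padded keys, wherever the padding falls (it need not be contiguous).
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask[:, None, None, :])
```

Because the mask gates each key position independently, arbitrary padding patterns are supported, which is what makes an explicit `txt_seq_lens` argument unnecessary.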