
Commit ac5ac24: remove txt_seq_lens and use bool mask
1 parent 88cee8b

4 files changed, +59 -42 lines changed


examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 0 additions & 2 deletions
@@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     height=model_input.shape[3],
                     width=model_input.shape[4],
                 )
-                print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
                 model_pred = transformer(
                     hidden_states=packed_noisy_model_input,
                     encoder_hidden_states=prompt_embeds,
                     encoder_hidden_states_mask=prompt_embeds_mask,
                     timestep=timesteps / 1000,
                     img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                     return_dict=False,
                 )[0]
                 model_pred = QwenImagePipeline._unpack_latents(
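For context, a minimal sketch (not from this commit) of the training-side inputs after the change: the per-sample text lengths no longer need to be passed, because the 1.0/0.0 mask alone tells the attention processor which tokens are padding. The shapes and embedding width below are hypothetical placeholders.

import torch

# Hypothetical batch: 2 prompts padded to 77 tokens, embedding width 3584 (illustrative only).
prompt_embeds = torch.randn(2, 77, 3584)
prompt_embeds_mask = torch.ones(2, 77)   # 1.0 = valid token, 0.0 = padding
prompt_embeds_mask[0, 50:] = 0.0         # first sample uses only 50 tokens

# Before this commit the lengths were duplicated as a Python list:
#     txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
# After it, the mask is the single source of truth and is passed as
# encoder_hidden_states_mask=prompt_embeds_mask, as shown in the diff above.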

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 13 additions & 22 deletions
@@ -189,12 +189,11 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
-        The [`FluxTransformer2DModel`] forward method.
+        The [`QwenImageControlNetModel`] forward method.
 
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -205,26 +204,24 @@ def forward(
                 The scale factor for ControlNet outputs.
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            txt_seq_lens (`List[int]`, *optional*):
-                Optional text sequence lengths. If omitted, or shorter than the encoder hidden states length, the model
-                derives the length from the encoder hidden states (or their mask).
+            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+                Image shapes for RoPE computation.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
 
         Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+            the first element is the controlnet block samples.
         """
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
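To make the documented mask semantics concrete, here is a small standalone sketch (mine, not part of the diff) of a valid encoder_hidden_states_mask; only zero vs. non-zero matters, so non-contiguous patterns are fine:

import torch

# Hypothetical batch of two prompts padded to 8 positions.
encoder_hidden_states_mask = torch.tensor(
    [
        [1, 1, 1, 1, 1, 0, 0, 0],  # 5 valid tokens followed by padding
        [1, 0, 1, 0, 1, 0, 0, 0],  # a non-contiguous pattern is also allowed
    ],
    dtype=torch.float32,
)

# The attention processor only distinguishes zero from non-zero, so .bool()
# recovers "attend" (True) vs. "don't attend" (False) per text position.
print(encoder_hidden_states_mask.bool())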
@@ -247,13 +244,9 @@ def forward(
 
         temb = self.time_text_embed(timestep, hidden_states)
 
-        batch_size, text_seq_len = encoder_hidden_states.shape[:2]
-        if txt_seq_lens is not None:
-            if len(txt_seq_lens) != batch_size:
-                raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.")
-            text_seq_len = max(text_seq_len, max(txt_seq_lens))
-        elif encoder_hidden_states_mask is not None:
-            text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item()))
+        # Use the encoder_hidden_states sequence length for RoPE computation
+        # The mask is used for attention masking in the attention processor
+        _, text_seq_len = encoder_hidden_states.shape[:2]
 
         image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)
 
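A quick sketch with made-up numbers of what this hunk changes: the RoPE text length now comes directly from the padded embedding length, so the txt_seq_lens bookkeeping (and its batch-size validation) disappears, and token validity is handled purely by the attention mask.

import torch

# Hypothetical padded batch: 3 prompts padded to 10 tokens, with 4, 7 and 10 valid tokens.
encoder_hidden_states = torch.randn(3, 10, 64)
encoder_hidden_states_mask = torch.zeros(3, 10)
for i, n in enumerate([4, 7, 10]):
    encoder_hidden_states_mask[i, :n] = 1.0

# Removed path: callers supplied per-sample lengths and the model took
# max(padded_length, max(txt_seq_lens)) as the RoPE text length.
txt_seq_lens = [4, 7, 10]
old_text_seq_len = max(encoder_hidden_states.shape[1], max(txt_seq_lens))  # 10

# New path: the padded length is used directly.
_, new_text_seq_len = encoder_hidden_states.shape[:2]                      # 10
print(old_text_seq_len, new_text_seq_len)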

@@ -332,7 +325,6 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[QwenImageControlNetOutput, Tuple]:
@@ -350,7 +342,6 @@ def forward(
             encoder_hidden_states_mask=encoder_hidden_states_mask,
             timestep=timestep,
             img_shapes=img_shapes,
-            txt_seq_lens=txt_seq_lens,
             joint_attention_kwargs=joint_attention_kwargs,
             return_dict=return_dict,
         )

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 18 additions & 14 deletions
@@ -330,6 +330,8 @@ def __call__(
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
         # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask.
+        # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding.
+        # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend).
         if encoder_hidden_states_mask is not None and attention_mask is None:
             batch_size, image_seq_len = hidden_states.shape[:2]
             text_seq_len = encoder_hidden_states.shape[1]
@@ -345,7 +347,9 @@ def __call__(
                     f"must match encoder_hidden_states sequence length ({text_seq_len})."
                 )
 
-            text_attention_mask = encoder_hidden_states_mask.to(dtype=torch.bool, device=hidden_states.device)
+            # Convert mask to boolean: 1/1.0 -> True (attend), 0/0.0 -> False (don't attend)
+            # This is the correct semantics for PyTorch's scaled_dot_product_attention with boolean masks.
+            text_attention_mask = encoder_hidden_states_mask.bool()
             image_attention_mask = torch.ones(
                 (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
             )
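For readers who want to see how such a boolean mask behaves downstream, here is a self-contained sketch (hypothetical shapes, not code from this file) of building a joint text+image mask and handing it to PyTorch's scaled_dot_product_attention, where True means "attend" and False means "mask out":

import torch
import torch.nn.functional as F

batch_size, num_heads, head_dim = 2, 4, 8
text_seq_len, image_seq_len = 6, 10
joint_seq_len = text_seq_len + image_seq_len

# 1.0 = valid text token, 0.0 = padding (any pattern is allowed).
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len)
encoder_hidden_states_mask[0, 4:] = 0.0

# Text mask becomes boolean; image tokens are always attended.
text_attention_mask = encoder_hidden_states_mask.bool()
image_attention_mask = torch.ones(batch_size, image_seq_len, dtype=torch.bool)
joint_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)  # (B, S)

# Broadcast to (B, 1, 1, S): every query may attend to any key whose mask entry is True.
attn_mask = joint_mask[:, None, None, :]

q = torch.randn(batch_size, num_heads, joint_seq_len, head_dim)
k = torch.randn(batch_size, num_heads, joint_seq_len, head_dim)
v = torch.randn(batch_size, num_heads, joint_seq_len, head_dim)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 16, 8])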
@@ -592,7 +596,6 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         guidance: torch.Tensor = None,  # TODO: this should probably be removed
         attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
@@ -606,17 +609,22 @@ def forward(
                 Input `hidden_states`.
             encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
-                Mask of the input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
-            txt_seq_lens (`List[int]`, *optional*):
-                Optional text sequence lengths. If not provided, or if any provided values are shorter than the encoder
-                hidden states length, the model falls back to the encoder hidden states length.
+            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+                Image shapes for RoPE computation.
+            guidance (`torch.Tensor`, *optional*):
+                Guidance tensor for conditional generation.
             attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_block_samples (*optional*):
+                ControlNet block samples to add to the transformer blocks.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
@@ -646,13 +654,9 @@ def forward(
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
-        batch_size, text_seq_len = encoder_hidden_states.shape[:2]
-        if txt_seq_lens is not None:
-            if len(txt_seq_lens) != batch_size:
-                raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.")
-            text_seq_len = max(text_seq_len, max(txt_seq_lens))
-        elif encoder_hidden_states_mask is not None:
-            text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item()))
+        # Use the encoder_hidden_states sequence length for RoPE computation
+        # The mask is used for attention masking in the attention processor
+        _, text_seq_len = encoder_hidden_states.shape[:2]
 
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 28 additions & 4 deletions
@@ -90,16 +90,20 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"QwenImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
-    def test_accepts_short_txt_seq_lens(self):
+    def test_infers_text_seq_len_from_mask(self):
         init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
 
-        # Provide a deliberately short txt_seq_lens to ensure the model falls back to the embedding length.
-        inputs["txt_seq_lens"] = [2] * inputs["encoder_hidden_states"].shape[0]
+        # Create a mask with only 2 valid tokens (rest are padding)
+        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
+        encoder_hidden_states_mask[:, 2:] = 0  # Only first 2 tokens are valid
+
+        inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
 
         with torch.no_grad():
             output = model(**inputs)
 
+        # RoPE uses the full padded length; the mask restricts attention to the 2 valid tokens
         self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
 
     def test_builds_attention_mask_from_encoder_mask(self):
@@ -111,13 +115,33 @@ def test_builds_attention_mask_from_encoder_mask(self):
         encoder_hidden_states_mask[:, -2:] = 0
 
         inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
-        inputs.pop("txt_seq_lens", None)
 
         with torch.no_grad():
             output = model(**inputs)
 
         self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
 
+    def test_non_contiguous_attention_mask(self):
+        """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
+        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
+        # Pattern: [True, False, True, False, True, False, False]
+        encoder_hidden_states_mask[:, 1] = 0
+        encoder_hidden_states_mask[:, 3] = 0
+        encoder_hidden_states_mask[:, 5:] = 0
+
+        inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
+
+        with torch.no_grad():
+            output = model(**inputs)
+
+        # The model should handle non-contiguous masks correctly
+        # RoPE uses the full sequence length, attention masking handles the pattern
+        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+
 
 class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = QwenImageTransformer2DModel
