Commit 2ef38e2

fix compile
1 parent b5b6342 commit 2ef38e2

File tree

2 files changed (+121, -36 lines):

  src/diffusers/models/transformers/transformer_qwenimage.py
  tests/models/transformers/test_models_transformer_qwenimage.py

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 30 additions & 32 deletions
@@ -165,12 +165,7 @@ def compute_text_seq_len_from_mask(
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
     per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
-
-    # For RoPE, we use the full text_seq_len (since per_sample_len.max() <= text_seq_len always)
-    # Keep as tensor to avoid graph breaks in torch.compile
-    rope_text_seq_len = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
-
-    return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask
+    return text_seq_len, per_sample_len, encoder_hidden_states_mask


 class QwenTimestepProjEmbeddings(nn.Module):
@@ -271,10 +266,6 @@ def forward(
         if max_txt_seq_len is None:
             raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")

-        # Move to device unconditionally to avoid graph breaks in torch.compile
-        self.pos_freqs = self.pos_freqs.to(device)
-        self.neg_freqs = self.neg_freqs.to(device)
-
         # Validate batch inference with variable-sized images
         if isinstance(video_fhw, list) and len(video_fhw) > 1:
             # Check if all instances have the same size
@@ -297,25 +288,29 @@ def forward(
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
-            video_freq = self._compute_video_freqs(frame, height, width, idx)
-            video_freq = video_freq.to(device)
+            video_freq = self._compute_video_freqs(frame, height, width, idx, device)
             vid_freqs.append(video_freq)

             if self.scale_rope:
                 max_vid_index = max(height // 2, width // 2, max_vid_index)
             else:
                 max_vid_index = max(height, width, max_vid_index)

-        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...]
+        max_txt_seq_len_int = int(max_txt_seq_len)
+        # Create device-specific copy for text freqs without modifying self.pos_freqs
+        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

     @functools.lru_cache(maxsize=128)
-    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
+    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None) -> torch.Tensor:
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -384,10 +379,6 @@ def forward(
             device: (`torch.device`, *optional*):
                 The device on which to perform the RoPE computation.
         """
-        # Move to device unconditionally to avoid graph breaks in torch.compile
-        self.pos_freqs = self.pos_freqs.to(device)
-        self.neg_freqs = self.neg_freqs.to(device)
-
         # Validate batch inference with variable-sized images
         # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
         if isinstance(video_fhw, list) and len(video_fhw) > 1:
@@ -412,11 +403,10 @@ def forward(
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             if idx != layer_num:
-                video_freq = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self._compute_video_freqs(frame, height, width, idx, device)
             else:
                 ### For the condition image, we set the layer index to -1
-                video_freq = self._compute_condition_freqs(frame, height, width)
-                video_freq = video_freq.to(device)
+                video_freq = self._compute_condition_freqs(frame, height, width, device)
             vid_freqs.append(video_freq)

             if self.scale_rope:
@@ -425,16 +415,21 @@ def forward(
                 max_vid_index = max(height, width, max_vid_index)

         max_vid_index = max(max_vid_index, layer_num)
-        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...]
+        max_txt_seq_len_int = int(max_txt_seq_len)
+        # Create device-specific copy for text freqs without modifying self.pos_freqs
+        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width, idx=0):
+    def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -450,10 +445,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
         return freqs.clone().contiguous()

     @functools.lru_cache(maxsize=None)
-    def _compute_condition_freqs(self, frame, height, width):
+    def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

         freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -911,8 +909,8 @@ def forward(
                 "txt_seq_lens",
                 "0.37.0",
                 "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. "
-                "Please use `txt_seq_len` instead (singular, not plural). "
-                "The new parameter accepts a single int or tensor value instead of a list.",
+                "Please use `encoder_hidden_states_mask` instead. "
+                "The mask-based approach is more flexible and supports variable-length sequences.",
                 standard_warn=False,
             )
         if attention_kwargs is not None:
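
Note on the hunks above: the fix follows one pattern for torch.compile friendliness. forward() no longer reassigns self.pos_freqs / self.neg_freqs to the target device; instead the device is passed into the lru_cache'd frequency helpers, which work on a device-local copy, and the text sequence length is kept as a plain int. The snippet below is a minimal, self-contained sketch of that pattern, not the diffusers implementation: the RopeCache class, its shapes, and _compute_freqs are illustrative stand-ins.

import functools

import torch


class RopeCache:
    def __init__(self, max_len: int = 64, dim: int = 8):
        # Stand-in for the cached frequency table (self.pos_freqs in the diff);
        # the random values are dummy data.
        self.pos_freqs = torch.randn(max_len, dim)

    @functools.lru_cache(maxsize=128)
    def _compute_freqs(self, length: int, device: torch.device = None) -> torch.Tensor:
        # Device handling lives inside the cached helper: copy to the requested
        # device instead of mutating self.pos_freqs. The device is part of the
        # lru_cache key, so each device gets its own cached result.
        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
        return pos_freqs[:length].contiguous()

    def forward(self, length, device):
        # forward() stays free of attribute reassignment and tensor construction,
        # mirroring the intent of the "fix compile" commit.
        return self._compute_freqs(int(length), device)


freqs = RopeCache().forward(4, torch.device("cpu"))  # -> shape (4, 8)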

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 91 additions & 4 deletions
@@ -103,9 +103,8 @@ def test_infers_text_seq_len_from_mask(self):
             inputs["encoder_hidden_states"], encoder_hidden_states_mask
         )

-        # Verify rope_text_seq_len is returned as a tensor (for torch.compile compatibility)
-        self.assertIsInstance(rope_text_seq_len, torch.Tensor)
-        self.assertEqual(rope_text_seq_len.ndim, 0)  # Should be scalar tensor
+        # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
+        self.assertIsInstance(rope_text_seq_len, int)

         # Verify per_sample_len is computed correctly (max valid position + 1 = 2)
         self.assertIsInstance(per_sample_len, torch.Tensor)
@@ -116,7 +115,7 @@ def test_infers_text_seq_len_from_mask(self):
         self.assertEqual(normalized_mask.sum().item(), 2)  # Only 2 True values

         # Verify rope_text_seq_len is at least the sequence length
-        self.assertGreaterEqual(int(rope_text_seq_len.item()), inputs["encoder_hidden_states"].shape[1])
+        self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])

         # Test 2: Verify model runs successfully with inferred values
         inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -142,6 +141,7 @@ def test_infers_text_seq_len_from_mask(self):
             inputs["encoder_hidden_states"], None
         )
         self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
+        self.assertIsInstance(rope_text_seq_len_none, int)
         self.assertIsNone(per_sample_len_none)
         self.assertIsNone(normalized_mask_none)

@@ -162,6 +162,7 @@ def test_non_contiguous_attention_mask(self):
         )
         self.assertEqual(int(per_sample_len.max().item()), 5)
         self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
+        self.assertIsInstance(inferred_rope_len, int)
         self.assertTrue(normalized_mask.dtype == torch.bool)

         inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -171,6 +172,92 @@ def test_non_contiguous_attention_mask(self):

         self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

+    def test_txt_seq_lens_deprecation(self):
+        """Test that passing txt_seq_lens raises a deprecation warning."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Prepare inputs with txt_seq_lens (deprecated parameter)
+        txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
+
+        # Remove encoder_hidden_states_mask to use the deprecated path
+        inputs_with_deprecated = inputs.copy()
+        inputs_with_deprecated.pop("encoder_hidden_states_mask")
+        inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
+
+        # Test that deprecation warning is raised
+        with self.assertWarns(FutureWarning) as warning_context:
+            with torch.no_grad():
+                output = model(**inputs_with_deprecated)
+
+        # Verify the warning message mentions the deprecation
+        warning_message = str(warning_context.warning)
+        self.assertIn("txt_seq_lens", warning_message)
+        self.assertIn("deprecated", warning_message)
+        self.assertIn("encoder_hidden_states_mask", warning_message)
+
+        # Verify the model still works correctly despite the deprecation
+        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+
+    def test_layered_model_with_mask(self):
+        """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
+        # Create layered model config
+        init_dict = {
+            "patch_size": 2,
+            "in_channels": 16,
+            "out_channels": 16,
+            "num_layers": 2,
+            "attention_head_dim": 128,
+            "num_attention_heads": 4,
+            "joint_attention_dim": 16,
+            "use_layer3d_rope": True,  # Enable layered RoPE
+        }
+
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Verify the model uses QwenEmbedLayer3DRope
+        from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
+
+        self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
+
+        # Test single generation with layered structure
+        batch_size = 1
+        text_seq_len = 7
+        img_h, img_w = 4, 4
+        layers = 4
+
+        # For layered model: (layers + 1) because we have N layers + 1 combined image
+        hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
+        encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
+
+        # Create mask with some padding
+        encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
+        encoder_hidden_states_mask[0, 5:] = 0  # Only 5 valid tokens
+
+        timestep = torch.tensor([1.0]).to(torch_device)
+
+        # Layer structure: 4 layers + 1 condition image
+        img_shapes = [
+            [
+                (1, img_h, img_w),  # layer 0
+                (1, img_h, img_w),  # layer 1
+                (1, img_h, img_w),  # layer 2
+                (1, img_h, img_w),  # layer 3
+                (1, img_h, img_w),  # condition image (last one gets special treatment)
+            ]
+        ]
+
+        with torch.no_grad():
+            output = model(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                timestep=timestep,
+                img_shapes=img_shapes,
+            )
+
+        self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
+

 class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = QwenImageTransformer2DModel