
Commit 35cb2c8

Apply ruff formatting to QwenImage warning implementation

- Fix whitespace and string quote consistency
- Add trailing commas where appropriate
- Clean up formatting per diffusers code standards

1 parent 39462a4

File tree: 2 files changed (+45, -39)

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 33 additions & 27 deletions
@@ -164,22 +164,28 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self._current_max_len = 1024
         pos_index = torch.arange(self._current_max_len)
         neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
-        self.register_buffer('pos_freqs', torch.cat(
-            [
-                self.rope_params(pos_index, self.axes_dim[0], self.theta),
-                self.rope_params(pos_index, self.axes_dim[1], self.theta),
-                self.rope_params(pos_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
-        ))
-        self.register_buffer('neg_freqs', torch.cat(
-            [
-                self.rope_params(neg_index, self.axes_dim[0], self.theta),
-                self.rope_params(neg_index, self.axes_dim[1], self.theta),
-                self.rope_params(neg_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
-        ))
+        self.register_buffer(
+            "pos_freqs",
+            torch.cat(
+                [
+                    self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
+        )
+        self.register_buffer(
+            "neg_freqs",
+            torch.cat(
+                [
+                    self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
+        )
         self.rope_cache = {}

         # whether to use scale rope
@@ -199,22 +205,22 @@ def _expand_pos_freqs_if_needed(self, required_len):
         """Expand pos_freqs and neg_freqs if required length exceeds current size"""
         if required_len <= self._current_max_len:
             return
-
+
         # Calculate new size (round up to the nearest multiple of 512 for efficiency)
         new_max_len = max(required_len, int((required_len + 511) // 512) * 512)
-
+
         # Log warning about potential quality degradation for long prompts
         if required_len > 512:
             logger.warning(
                 f"QwenImage model was trained on prompts up to 512 tokens. "
                 f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
                 f"Consider using shorter prompts for better results."
             )
-
+
         # Generate expanded indices
         pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
         neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1
-
+
         # Generate expanded frequency embeddings
         new_pos_freqs = torch.cat(
             [
@@ -224,7 +230,7 @@ def _expand_pos_freqs_if_needed(self, required_len):
             ],
             dim=1,
         ).to(device=self.pos_freqs.device, dtype=self.pos_freqs.dtype)
-
+
         new_neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
@@ -233,12 +239,12 @@ def _expand_pos_freqs_if_needed(self, required_len):
             ],
             dim=1,
         ).to(device=self.neg_freqs.device, dtype=self.neg_freqs.dtype)
-
+
         # Update buffers
-        self.register_buffer('pos_freqs', new_pos_freqs)
-        self.register_buffer('neg_freqs', new_neg_freqs)
+        self.register_buffer("pos_freqs", new_pos_freqs)
+        self.register_buffer("neg_freqs", new_neg_freqs)
         self._current_max_len = new_max_len
-
+
         # Clear cache since dimensions changed
         self.rope_cache = {}

@@ -281,11 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
         max_vid_index = max(height, width)

         max_len = max(txt_seq_lens)
-
+
         # Expand pos_freqs if needed to accommodate max_vid_index + max_len
         required_len = max_vid_index + max_len
         self._expand_pos_freqs_if_needed(required_len)
-
+
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]

         return vid_freqs, txt_freqs
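
The ruff pass above is behavior-preserving: it re-wraps the register_buffer calls, normalizes quotes, and adds trailing commas. The underlying pattern the diff touches — round the requested length up to the next multiple of 512, rebuild the frequency table, and re-register it under the same buffer name — can be exercised on its own. A minimal runnable sketch, with a hypothetical class name and random tensors standing in for the real rope_params tables:

import torch
import torch.nn as nn


class RopeBufferDemo(nn.Module):
    """Sketch of the expand-on-demand buffer pattern (not the diffusers class)."""

    def __init__(self, max_len: int = 1024, dim: int = 8):
        super().__init__()
        self._current_max_len = max_len
        self.register_buffer("pos_freqs", torch.randn(max_len, dim))

    def expand_if_needed(self, required_len: int) -> None:
        if required_len <= self._current_max_len:
            return
        # Same arithmetic as the commit: round up to the nearest multiple of 512.
        new_max_len = max(required_len, ((required_len + 511) // 512) * 512)
        new_freqs = torch.randn(
            new_max_len,
            self.pos_freqs.shape[1],
            device=self.pos_freqs.device,
            dtype=self.pos_freqs.dtype,
        )
        # Re-registering under an existing buffer name replaces the stored
        # tensor, so slicing like self.pos_freqs[a:b] keeps working unchanged.
        self.register_buffer("pos_freqs", new_freqs)
        self._current_max_len = new_max_len


demo = RopeBufferDemo()
demo.expand_if_needed(1300)  # ceil(1300 / 512) * 512 == 1536
assert demo.pos_freqs.shape[0] == 1536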

tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 12 additions & 12 deletions
@@ -241,43 +241,43 @@ def test_long_prompt_no_error(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(device)
-
+
         # Create a very long prompt that exceeds 1024 tokens when combined with image positioning
         # Repeat a long phrase to simulate a real long prompt scenario
         long_phrase = "A beautiful, detailed, high-resolution, photorealistic image showing "
         long_prompt = (long_phrase * 50)[:1200]  # Ensure we exceed 1024 characters
-
+
         inputs = {
             "prompt": long_prompt,
             "generator": torch.Generator(device=device).manual_seed(0),
             "num_inference_steps": 2,
             "guidance_scale": 3.0,
             "true_cfg_scale": 1.0,
             "height": 32,  # Small size for fast test
-            "width": 32, # Small size for fast test
+            "width": 32,  # Small size for fast test
             "max_sequence_length": 1200,  # Allow long sequence
             "output_type": "pt",
         }
-
+
         # This should not raise a RuntimeError about tensor dimension mismatch
         _ = pipe(**inputs)

     def test_long_prompt_warning(self):
         """Test that long prompts trigger appropriate warning about training limitation"""
         from diffusers.utils import logging
-
+
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
-
+
         # Create prompt that will exceed 512 tokens to trigger warning
         long_phrase = "A detailed photorealistic description of a complex scene with many elements "
         long_prompt = (long_phrase * 20)[:800]  # Create a prompt that will exceed 512 tokens
-
-        # Capture transformer logging
+
+        # Capture transformer logging
         logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
         logger.setLevel(logging.WARNING)
-
+
         with CaptureLogger(logger) as cap_logger:
             _ = pipe(
                 prompt=long_prompt,
@@ -286,11 +286,11 @@ def test_long_prompt_warning(self):
                 guidance_scale=3.0,
                 true_cfg_scale=1.0,
                 height=32,  # Small size for fast test
-                width=32, # Small size for fast test
+                width=32,  # Small size for fast test
                 max_sequence_length=900,  # Allow long sequence
-                output_type="pt"
+                output_type="pt",
             )
-
+
         # Verify warning was logged about the 512-token training limitation
         self.assertTrue("512 tokens" in cap_logger.out)
         self.assertTrue("unpredictable behavior" in cap_logger.out)
