Skip to content

Commit 6c044b9

Browse files
committed
Improve test patterns for QwenImage long prompt warning
- Fix test_long_prompt_warning to properly trigger the 512-token warning
- Replace inefficient wall-of-text approach with elegant hardcoded multiplier
- Use precise token counting to ensure required_len > _current_max_len threshold
- Add runtime assertion for test robustness and maintainability
- Fix max_sequence_length validation error in test_long_prompt_no_error
1 parent 35cb2c8 commit 6c044b9

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,7 @@ def test_long_prompt_no_error(self):
255255
"true_cfg_scale": 1.0,
256256
"height": 32, # Small size for fast test
257257
"width": 32, # Small size for fast test
258-
"max_sequence_length": 1200, # Allow long sequence
258+
"max_sequence_length": 1024, # Allow long sequence (max allowed)
259259
"output_type": "pt",
260260
}
261261

@@ -270,9 +270,17 @@ def test_long_prompt_warning(self):
270270
pipe = self.pipeline_class(**components)
271271
pipe.to(torch_device)
272272

273-
# Create prompt that will exceed 512 tokens to trigger warning
274-
long_phrase = "A detailed photorealistic description of a complex scene with many elements "
275-
long_prompt = (long_phrase * 20)[:800] # Create a prompt that will exceed 512 tokens
273+
# Create a long prompt that will exceed the RoPE expansion threshold
274+
# The warning is triggered when required_len = max(height, width) + text_tokens > _current_max_len
275+
# Since _current_max_len is 1024 and height=width=32, we need > 992 tokens
276+
phrase = "A detailed photorealistic image showing many beautiful elements and complex artistic creative features with intricate designs."
277+
long_prompt = phrase * 58 # Generates ~1045 tokens, ensuring required_len > 1024
278+
279+
# Verify we exceed the threshold (for test robustness)
280+
tokenizer = components["tokenizer"]
281+
token_count = len(tokenizer.encode(long_prompt))
282+
required_len = 32 + token_count # height/width + tokens
283+
self.assertGreater(required_len, 1024, f"Test prompt must exceed threshold (got {required_len})")
276284

277285
# Capture transformer logging
278286
logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
@@ -287,7 +295,7 @@ def test_long_prompt_warning(self):
287295
true_cfg_scale=1.0,
288296
height=32, # Small size for fast test
289297
width=32, # Small size for fast test
290-
max_sequence_length=900, # Allow long sequence
298+
max_sequence_length=1024, # Allow long sequence
291299
output_type="pt",
292300
)
293301

0 commit comments

Comments (0)