diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 961ed72b73f5..b008d000bb09 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -160,23 +160,31 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
-            [
-                self.rope_params(pos_index, self.axes_dim[0], self.theta),
-                self.rope_params(pos_index, self.axes_dim[1], self.theta),
-                self.rope_params(pos_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
+        # Initialize with default size 1024, but allow dynamic expansion
+        self._current_max_len = 1024
+        pos_index = torch.arange(self._current_max_len)
+        neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
+        self.register_buffer(
+            "pos_freqs",
+            torch.cat(
+                [
+                    self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
         )
-        self.neg_freqs = torch.cat(
-            [
-                self.rope_params(neg_index, self.axes_dim[0], self.theta),
-                self.rope_params(neg_index, self.axes_dim[1], self.theta),
-                self.rope_params(neg_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
+        self.register_buffer(
+            "neg_freqs",
+            torch.cat(
+                [
+                    self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
         )
         self.rope_cache = {}
 
@@ -193,6 +201,53 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs
 
+    def _expand_pos_freqs_if_needed(self, required_len):
+        """Expand pos_freqs and neg_freqs if required length exceeds current size"""
+        if required_len <= self._current_max_len:
+            return
+
+        # Calculate new size (round required_len up to the next multiple of 512 for efficiency)
+        new_max_len = max(required_len, int((required_len + 511) // 512) * 512)
+
+        # Log warning about potential quality degradation for long prompts
+        if required_len > 512:
+            logger.warning(
+                f"QwenImage model was trained on prompts up to 512 tokens. "
+                f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
+                f"Consider using shorter prompts for better results."
+            )
+
+        # Generate expanded indices
+        pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
+        neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1
+
+        # Generate expanded frequency embeddings
+        new_pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        ).to(device=self.pos_freqs.device, dtype=self.pos_freqs.dtype)
+
+        new_neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        ).to(device=self.neg_freqs.device, dtype=self.neg_freqs.dtype)
+
+        # Update buffers
+        self.register_buffer("pos_freqs", new_pos_freqs)
+        self.register_buffer("neg_freqs", new_neg_freqs)
+        self._current_max_len = new_max_len
+
+        # Clear cache since dimensions changed
+        self.rope_cache = {}
+
     def forward(self, video_fhw, txt_seq_lens, device):
         """
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
@@ -232,6 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
         max_vid_index = max(height, width)
 
         max_len = max(txt_seq_lens)
+
+        # Expand pos_freqs if needed to accommodate max_vid_index + max_len
+        required_len = max_vid_index + max_len
+        self._expand_pos_freqs_if_needed(required_len)
+
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
 
         return vid_freqs, txt_freqs
diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py
index a312d0658fea..54ca4b204e9a 100644
--- a/tests/pipelines/qwenimage/test_qwenimage.py
+++ b/tests/pipelines/qwenimage/test_qwenimage.py
@@ -24,7 +24,7 @@
     QwenImagePipeline,
     QwenImageTransformer2DModel,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -234,3 +234,79 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
+
+    def test_long_prompt_no_error(self):
+        # Test for issue #12083: long prompts should not cause dimension mismatch errors
+        device = torch_device
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+
+        # Create a long prompt that approaches but stays within limits
+        # This tests the original issue fix without triggering the warning
+        phrase = "A beautiful, detailed, high-resolution, photorealistic image showing "
+        long_prompt = phrase * 40  # Generates ~800 tokens, well within limits
+
+        # Verify token count for test clarity
+        tokenizer = components["tokenizer"]
+        token_count = len(tokenizer.encode(long_prompt))
+        required_len = 32 + token_count  # height/width + tokens
+        # Should be large enough to test the fix but not trigger expansion warning
+        self.assertGreater(token_count, 500, f"Test prompt should be substantial (got {token_count} tokens)")
+        self.assertLess(required_len, 1024, f"Test should stay within limits (got {required_len})")
+
+        inputs = {
+            "prompt": long_prompt,
+            "generator": torch.Generator(device=device).manual_seed(0),
"num_inference_steps": 2, + "guidance_scale": 3.0, + "true_cfg_scale": 1.0, + "height": 32, # Small size for fast test + "width": 32, # Small size for fast test + "max_sequence_length": 1024, # Allow long sequence (max allowed) + "output_type": "pt", + } + + # This should not raise a RuntimeError about tensor dimension mismatch + _ = pipe(**inputs) + + def test_long_prompt_warning(self): + """Test that long prompts trigger appropriate warning about training limitation""" + from diffusers.utils import logging + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + + # Create a long prompt that will exceed the RoPE expansion threshold + # The warning is triggered when required_len = max(height, width) + text_tokens > _current_max_len + # Since _current_max_len is 1024 and height=width=32, we need > 992 tokens + phrase = "A detailed photorealistic image showing many beautiful elements and complex artistic creative features with intricate designs." + long_prompt = phrase * 58 # Generates ~1045 tokens, ensuring required_len > 1024 + + # Verify we exceed the threshold (for test robustness) + tokenizer = components["tokenizer"] + token_count = len(tokenizer.encode(long_prompt)) + required_len = 32 + token_count # height/width + tokens + self.assertGreater(required_len, 1024, f"Test prompt must exceed threshold (got {required_len})") + + # Capture transformer logging + logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage") + logger.setLevel(logging.WARNING) + + with CaptureLogger(logger) as cap_logger: + _ = pipe( + prompt=long_prompt, + generator=torch.Generator(device=torch_device).manual_seed(0), + num_inference_steps=2, + guidance_scale=3.0, + true_cfg_scale=1.0, + height=32, # Small size for fast test + width=32, # Small size for fast test + max_sequence_length=1024, # Allow long sequence + output_type="pt", + ) + + # Verify warning was logged about the 512-token training limitation + self.assertTrue("512 tokens" in cap_logger.out) + self.assertTrue("unpredictable behavior" in cap_logger.out)