Commit 58ea3bf (parent 74c91a0)

fix(qwen image):
- compatible with torch.compile in new rope setting
- fix init import
- add prompt truncation in img2img and inpaint pipelines
- remove unused logic and comments
- add "Copied from" statements
- guard logic for rope video shape tuple

File tree: 6 files changed, +58 −110 lines
src/diffusers/models/transformers/transformer_qwenimage.py (+29 −21)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -204,33 +204,22 @@ def forward(self, video_fhw, txt_seq_lens, device):
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]

         vid_freqs = []
         max_vid_index = 0
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             rope_key = f"{idx}_{height}_{width}"

-            if rope_key not in self.rope_cache:
-                seq_lens = frame * height * width
-                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-                if self.scale_rope:
-                    freqs_height = torch.cat(
-                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
-                    )
-                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-                else:
-                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-                self.rope_cache[rope_key] = freqs.clone().contiguous()
-            vid_freqs.append(self.rope_cache[rope_key])
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            vid_freqs.append(video_freq)

         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2, max_vid_index)

@@ -243,6 +232,25 @@ def forward(self, video_fhw, txt_seq_lens, device):

         return vid_freqs, txt_freqs

+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width, idx=0):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+

 class QwenDoubleStreamAttnProcessor2_0:
     """

src/diffusers/pipelines/__init__.py (+6 −1)

@@ -709,7 +709,12 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .qwenimage import QwenImageEditPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+        from .qwenimage import (
+            QwenImageEditPipeline,
+            QwenImageImg2ImgPipeline,
+            QwenImageInpaintPipeline,
+            QwenImagePipeline,
+        )
         from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
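
With the export added, the img2img pipeline is importable from the package root like its siblings (this is the "fix init import" item from the commit message):

from diffusers import (
    QwenImageEditPipeline,
    QwenImageImg2ImgPipeline,  # newly exported by this commit
    QwenImageInpaintPipeline,
    QwenImagePipeline,
)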

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py (+3 −20)

@@ -319,20 +319,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)

@@ -405,8 +391,7 @@ def prepare_latents(
         shape = (batch_size, 1, num_channels_latents, height, width)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(

@@ -417,9 +402,7 @@ def prepare_latents(
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents

     @property
     def guidance_scale(self):

@@ -597,7 +580,7 @@ def __call__(

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
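
For reference, the deleted helper only built a positional index grid; a standalone sketch of what it computed (logic taken from the removed code) shows why it is no longer needed now that the transformer derives rope frequencies directly from the (frame, height, width) tuple:

import torch

# Reproduces the removed _prepare_latent_image_ids: a (height*width, 3) grid
# of (frame, row, col) positions, with frame fixed at 0 for generated latents.
height, width = 4, 4
ids = torch.zeros(height, width, 3)
ids[..., 1] += torch.arange(height)[:, None]  # row index per token
ids[..., 2] += torch.arange(width)[None, :]   # column index per token
ids = ids.reshape(height * width, 3)
print(ids.shape)  # torch.Size([16, 3])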

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py (+8 −26)

@@ -76,6 +76,7 @@
 ]


+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
 def calculate_shift(
     image_seq_len,
     base_seq_len: int = 256,

@@ -221,12 +222,12 @@ def __init__(
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.vl_processor = processor
         self.tokenizer_max_length = 1024
-        # self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
-        # self.prompt_template_encode_start_idx = 34
+
         self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 64
         self.default_sample_size = 128

+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
     def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
         bool_mask = mask.bool()
         valid_lengths = bool_mask.sum(dim=1)

@@ -379,20 +380,7 @@ def check_inputs(
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

     @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
-    @staticmethod
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latents = latents.permute(0, 2, 4, 1, 3, 5)

@@ -401,6 +389,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         return latents

     @staticmethod
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape

@@ -487,7 +476,7 @@ def prepare_latents(
         shape = (batch_size, 1, num_channels_latents, height, width)

-        image_latents = image_ids = None
+        image_latents = None
         if image is not None:
             image = image.to(device=device, dtype=dtype)
             if image.shape[1] != self.latent_channels:

@@ -509,13 +498,6 @@ def prepare_latents(
             image_latents = self._pack_latents(
                 image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
             )
-            image_ids = self._prepare_latent_image_ids(
-                batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
-            )
-            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
-            image_ids[..., 0] = 1
-
-        latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(

@@ -528,7 +510,7 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)

-        return latents, image_latents, latent_ids, image_ids
+        return latents, image_latents

     @property
     def guidance_scale(self):

@@ -732,7 +714,7 @@ def __call__(

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, image_latents, latent_ids, image_ids = self.prepare_latents(
+        latents, image_latents = self.prepare_latents(
             image,
             batch_size * num_images_per_prompt,
             num_channels_latents,
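
The removed image_ids were the latent ids with the frame channel set to 1, so reference-image tokens occupied a distinct rope frame slot; under the new scheme that distinction comes from each entry's index in the video_fhw list instead. A short sketch of the old convention (illustrative only, adapted from the removed code):

import torch

latent_ids = torch.zeros(16, 3)  # (frame=0, row, col) for generated latents
image_ids = latent_ids.clone()
image_ids[..., 0] = 1            # frame=1 marked the edit-image tokens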

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py (+6 −21)

@@ -296,6 +296,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

@@ -363,21 +366,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):

@@ -465,8 +453,7 @@ def prepare_latents(
             raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)

         image = image.to(device=device, dtype=dtype)
         if image.shape[1] != self.latent_channels:

@@ -489,9 +476,7 @@ def prepare_latents(
         latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents

     @property
     def guidance_scale(self):

@@ -713,7 +698,7 @@ def __call__(

         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             init_image,
             latent_timestep,
             batch_size * num_images_per_prompt,
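
The truncation added to encode_prompt (here and in the inpaint pipeline below) slices the embeddings and their mask to max_sequence_length along the sequence axis. A sketch with stand-in shapes (the hidden size is an assumption; it depends on the text encoder):

import torch

max_sequence_length = 512
prompt_embeds = torch.randn(1, 700, 3584)             # (batch, seq, hidden)
prompt_embeds_mask = torch.ones(1, 700, dtype=torch.long)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
print(prompt_embeds.shape, prompt_embeds_mask.shape)
# torch.Size([1, 512, 3584]) torch.Size([1, 512])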

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py (+6 −21)

@@ -307,6 +307,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

@@ -390,21 +393,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):

@@ -492,8 +480,7 @@ def prepare_latents(
             raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)

         image = image.to(device=device, dtype=dtype)
         if image.shape[1] != self.latent_channels:

@@ -524,9 +511,7 @@ def prepare_latents(
         image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, noise, image_latents, latent_image_ids
+        return latents, noise, image_latents

     def prepare_mask_latents(
         self,

@@ -859,7 +844,7 @@ def __call__(
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4

-        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+        latents, noise, image_latents = self.prepare_latents(
             init_image,
             latent_timestep,
             batch_size * num_images_per_prompt,
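
End to end, the changed return arity of prepare_latents is internal; calling the pipeline is unchanged. A hedged usage sketch (the checkpoint id, file names, and argument values are assumptions, not taken from this commit):

import torch
from diffusers import QwenImageInpaintPipeline
from diffusers.utils import load_image

pipe = QwenImageInpaintPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")  # hypothetical local files
mask = load_image("mask.png")
result = pipe(
    prompt="a red vintage car parked by the sea",
    image=image,
    mask_image=mask,
    strength=0.85,               # fraction of the schedule used to renoise
).images[0]
result.save("result.png")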
