Skip to content

Commit 51fadec

Browse files
sayakpaul, apolinario, and cbensimon
committed
feat: make QwenImage family fully compilable again.
Co-authored-by: apolinario <[email protected]>
Co-authored-by: cbensimon <[email protected]>
1 parent 552c127 commit 51fadec

File tree

2 files changed: +9 additions, -6 deletions (lines changed)

2 files changed

+9
-6
lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def forward(
557557
attention_kwargs: Optional[Dict[str, Any]] = None,
558558
controlnet_block_samples=None,
559559
return_dict: bool = True,
560+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
560561
) -> Union[torch.Tensor, Transformer2DModelOutput]:
561562
"""
562563
The [`QwenTransformer2DModel`] forward method.
@@ -611,8 +612,8 @@ def forward(
611612
if guidance is None
612613
else self.time_text_embed(timestep, guidance, hidden_states)
613614
)
614-
615-
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
615+
if image_rotary_emb is None:
616+
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
616617

617618
for index_block, block in enumerate(self.transformer_blocks):
618619
if torch.is_grad_enabled() and self.gradient_checkpointing:

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,10 @@ def __call__(
631631
negative_txt_seq_lens = (
632632
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
633633
)
634+
image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
635+
neg_image_rotary_emb = None
636+
if do_true_cfg:
637+
neg_image_rotary_emb = self.transformer.pos_embed(img_shapes, negative_txt_seq_lens, device=latents.device)
634638

635639
# 6. Denoising loop
636640
self.scheduler.set_begin_index(0)
@@ -649,8 +653,7 @@ def __call__(
649653
guidance=guidance,
650654
encoder_hidden_states_mask=prompt_embeds_mask,
651655
encoder_hidden_states=prompt_embeds,
652-
img_shapes=img_shapes,
653-
txt_seq_lens=txt_seq_lens,
656+
image_rotary_emb=image_rotary_emb,
654657
attention_kwargs=self.attention_kwargs,
655658
return_dict=False,
656659
)[0]
@@ -663,8 +666,7 @@ def __call__(
663666
guidance=guidance,
664667
encoder_hidden_states_mask=negative_prompt_embeds_mask,
665668
encoder_hidden_states=negative_prompt_embeds,
666-
img_shapes=img_shapes,
667-
txt_seq_lens=negative_txt_seq_lens,
669+
image_rotary_emb=neg_image_rotary_emb,
668670
attention_kwargs=self.attention_kwargs,
669671
return_dict=False,
670672
)[0]

0 commit comments

Comments
 (0)