Skip to content

Commit b8a4cba

Browse files
[qwen-image] edit 2511 support (#12839)
* [qwen-image] edit 2511 support
* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 17c0e79 commit b8a4cba

File tree

1 file changed

+68
-5
lines changed

1 file changed

+68
-5
lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 68 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
import functools
1616
import math
17+
from math import prod
1718
from typing import Any, Dict, List, Optional, Tuple, Union
1819

1920
import numpy as np
@@ -363,7 +364,13 @@ def __call__(
363364
@maybe_allow_in_graph
364365
class QwenImageTransformerBlock(nn.Module):
365366
def __init__(
366-
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
367+
self,
368+
dim: int,
369+
num_attention_heads: int,
370+
attention_head_dim: int,
371+
qk_norm: str = "rms_norm",
372+
eps: float = 1e-6,
373+
zero_cond_t: bool = False,
367374
):
368375
super().__init__()
369376

@@ -403,10 +410,43 @@ def __init__(
403410
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
404411
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
405412

406-
def _modulate(self, x, mod_params):
413+
self.zero_cond_t = zero_cond_t
414+
415+
def _modulate(self, x, mod_params, index=None):
407416
"""Apply modulation to input tensor"""
417+
# x: b l d, shift: b d, scale: b d, gate: b d
408418
shift, scale, gate = mod_params.chunk(3, dim=-1)
409-
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
419+
420+
if index is not None:
421+
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
422+
# So shift, scale, gate have shape [2*actual_batch, d]
423+
actual_batch = shift.size(0) // 2
424+
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
425+
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
426+
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
427+
428+
# index: [b, l] where b is actual batch size
429+
# Expand to [b, l, 1] to match feature dimension
430+
index_expanded = index.unsqueeze(-1) # [b, l, 1]
431+
432+
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
433+
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
434+
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
435+
scale_0_exp = scale_0.unsqueeze(1)
436+
scale_1_exp = scale_1.unsqueeze(1)
437+
gate_0_exp = gate_0.unsqueeze(1)
438+
gate_1_exp = gate_1.unsqueeze(1)
439+
440+
# Use torch.where to select based on index
441+
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
442+
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
443+
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
444+
else:
445+
shift_result = shift.unsqueeze(1)
446+
scale_result = scale.unsqueeze(1)
447+
gate_result = gate.unsqueeze(1)
448+
449+
return x * (1 + scale_result) + shift_result, gate_result
410450

411451
def forward(
412452
self,
@@ -416,9 +456,13 @@ def forward(
416456
temb: torch.Tensor,
417457
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
418458
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
459+
modulate_index: Optional[List[int]] = None,
419460
) -> Tuple[torch.Tensor, torch.Tensor]:
420461
# Get modulation parameters for both streams
421462
img_mod_params = self.img_mod(temb) # [B, 6*dim]
463+
464+
if self.zero_cond_t:
465+
temb = torch.chunk(temb, 2, dim=0)[0]
422466
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
423467

424468
# Split modulation parameters for norm1 and norm2
@@ -427,7 +471,7 @@ def forward(
427471

428472
# Process image stream - norm1 + modulation
429473
img_normed = self.img_norm1(hidden_states)
430-
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
474+
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
431475

432476
# Process text stream - norm1 + modulation
433477
txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +501,7 @@ def forward(
457501

458502
# Process image stream - norm2 + MLP
459503
img_normed2 = self.img_norm2(hidden_states)
460-
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
504+
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
461505
img_mlp_output = self.img_mlp(img_modulated2)
462506
hidden_states = hidden_states + img_gate2 * img_mlp_output
463507

@@ -533,6 +577,7 @@ def __init__(
533577
joint_attention_dim: int = 3584,
534578
guidance_embeds: bool = False, # TODO: this should probably be removed
535579
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
580+
zero_cond_t: bool = False,
536581
):
537582
super().__init__()
538583
self.out_channels = out_channels or in_channels
@@ -553,6 +598,7 @@ def __init__(
553598
dim=self.inner_dim,
554599
num_attention_heads=num_attention_heads,
555600
attention_head_dim=attention_head_dim,
601+
zero_cond_t=zero_cond_t,
556602
)
557603
for _ in range(num_layers)
558604
]
@@ -562,6 +608,7 @@ def __init__(
562608
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
563609

564610
self.gradient_checkpointing = False
611+
self.zero_cond_t = zero_cond_t
565612

566613
def forward(
567614
self,
@@ -618,6 +665,17 @@ def forward(
618665
hidden_states = self.img_in(hidden_states)
619666

620667
timestep = timestep.to(hidden_states.dtype)
668+
669+
if self.zero_cond_t:
670+
timestep = torch.cat([timestep, timestep * 0], dim=0)
671+
modulate_index = torch.tensor(
672+
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
673+
device=timestep.device,
674+
dtype=torch.int,
675+
)
676+
else:
677+
modulate_index = None
678+
621679
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
622680
encoder_hidden_states = self.txt_in(encoder_hidden_states)
623681

@@ -641,6 +699,8 @@ def forward(
641699
encoder_hidden_states_mask,
642700
temb,
643701
image_rotary_emb,
702+
attention_kwargs,
703+
modulate_index,
644704
)
645705

646706
else:
@@ -651,6 +711,7 @@ def forward(
651711
temb=temb,
652712
image_rotary_emb=image_rotary_emb,
653713
joint_attention_kwargs=attention_kwargs,
714+
modulate_index=modulate_index,
654715
)
655716

656717
# controlnet residual
@@ -659,6 +720,8 @@ def forward(
659720
interval_control = int(np.ceil(interval_control))
660721
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
661722

723+
if self.zero_cond_t:
724+
temb = temb.chunk(2, dim=0)[0]
662725
# Use only the image part (hidden_states) from the dual-stream blocks
663726
hidden_states = self.norm_out(hidden_states, temb)
664727
output = self.proj_out(hidden_states)

0 commit comments

Comments
 (0)