Commit 36f569a

[qwen-image] edit 2511 support
1 parent 17c0e79 commit 36f569a

File tree

1 file changed: +59 -6 lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 59 additions & 6 deletions
@@ -14,6 +14,7 @@
 
 import functools
 import math
+from math import prod
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -363,7 +364,7 @@ def __call__(
 @maybe_allow_in_graph
 class QwenImageTransformerBlock(nn.Module):
     def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6, zero_cond_t: bool = False
     ):
         super().__init__()
 
@@ -403,10 +404,43 @@ def __init__(
         self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
         self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-    def _modulate(self, x, mod_params):
+        self.zero_cond_t = zero_cond_t
+
+    def _modulate(self, x, mod_params, index=None):
         """Apply modulation to input tensor"""
+        # x: b l d, shift: b d, scale: b d, gate: b d
         shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+        if index is not None:
+            # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
+            # So shift, scale, gate have shape [2*actual_batch, d]
+            actual_batch = shift.size(0) // 2
+            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]  # each: [actual_batch, d]
+            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+            gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+
+            # index: [b, l] where b is actual batch size
+            # Expand to [b, l, 1] to match feature dimension
+            index_expanded = index.unsqueeze(-1)  # [b, l, 1]
+
+            # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
+            shift_0_exp = shift_0.unsqueeze(1)  # [b, 1, d]
+            shift_1_exp = shift_1.unsqueeze(1)  # [b, 1, d]
+            scale_0_exp = scale_0.unsqueeze(1)
+            scale_1_exp = scale_1.unsqueeze(1)
+            gate_0_exp = gate_0.unsqueeze(1)
+            gate_1_exp = gate_1.unsqueeze(1)
+
+            # Use torch.where to select based on index
+            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+        else:
+            shift_result = shift.unsqueeze(1)
+            scale_result = scale.unsqueeze(1)
+            gate_result = gate.unsqueeze(1)
+
+        return x * (1 + scale_result) + shift_result, gate_result
 
     def forward(
         self,
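The per-token selection above is easier to see on toy tensors. Below is a minimal standalone sketch of the same `torch.where` trick; all sizes and the `index` values are made up for illustration and are not part of the commit:

import torch

b, l, d = 2, 5, 4                                    # batch, tokens, features
mod_params = torch.randn(2 * b, 3 * d)               # parameter set 0 stacked over set 1
index = torch.tensor([[0, 0, 0, 1, 1],
                      [0, 0, 1, 1, 1]])              # [b, l]: 0 = target, 1 = condition

shift, scale, gate = mod_params.chunk(3, dim=-1)     # each [2b, d]
shift_0 = shift[:b].unsqueeze(1)                     # [b, 1, d]
shift_1 = shift[b:].unsqueeze(1)                     # [b, 1, d]
shift_sel = torch.where(index.unsqueeze(-1) == 0, shift_0, shift_1)  # broadcasts to [b, l, d]

# token (0, 0) has index 0 -> set 0; token (0, 4) has index 1 -> set 1
assert torch.equal(shift_sel[0, 0], shift[0])
assert torch.equal(shift_sel[0, 4], shift[b])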
@@ -416,9 +450,13 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Get modulation parameters for both streams
         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+
+        if self.zero_cond_t:
+            temb = torch.chunk(temb, 2, dim=0)[0]
         txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
 
         # Split modulation parameters for norm1 and norm2
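With `zero_cond_t` enabled, `temb` carries two embeddings per sample along the batch dim (true timestep first, zero timestep second): the image modulation is computed on the doubled batch so `_modulate` can choose per token, while the text stream is restored to the true-timestep half first. A shape sketch, with `nn.Linear` standing in for the block's actual modulation projections (an assumption for illustration):

import torch
import torch.nn as nn

B, dim = 2, 8
img_mod = nn.Linear(dim, 6 * dim)            # stand-in for self.img_mod
txt_mod = nn.Linear(dim, 6 * dim)            # stand-in for self.txt_mod

temb = torch.randn(2 * B, dim)               # [true-t rows; zero-t rows]
img_mod_params = img_mod(temb)               # [2B, 6*dim], split per token in _modulate
temb = torch.chunk(temb, 2, dim=0)[0]        # [B, dim], keep the true-timestep half
txt_mod_params = txt_mod(temb)               # [B, 6*dim], text tokens never see t = 0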
@@ -427,7 +465,7 @@ def forward(
 
         # Process image stream - norm1 + modulation
         img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
 
         # Process text stream - norm1 + modulation
         txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +495,7 @@ def forward(
 
         # Process image stream - norm2 + MLP
         img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
         img_mlp_output = self.img_mlp(img_modulated2)
         hidden_states = hidden_states + img_gate2 * img_mlp_output
 
@@ -533,6 +571,7 @@ def __init__(
         joint_attention_dim: int = 3584,
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        zero_cond_t: bool = False
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -553,6 +592,7 @@ def __init__(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
+                    zero_cond_t=zero_cond_t
                 )
                 for _ in range(num_layers)
             ]
@@ -562,6 +602,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
+        self.zero_cond_t = zero_cond_t
 
     def forward(
         self,
@@ -618,6 +659,13 @@ def forward(
         hidden_states = self.img_in(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype)
+
+        if self.zero_cond_t:
+            timestep = torch.cat([timestep, timestep * 0], dim=0)
+            modulate_index = torch.tensor([[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], device=timestep.device, dtype=torch.int)
+        else:
+            modulate_index = None
+
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
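A quick worked example of the `modulate_index` construction (the `img_shapes` value here is invented): each per-sample entry lists token-grid shapes, with the first shape being the denoised target and any following shapes the condition images, so target tokens get index 0 and condition tokens index 1.

from math import prod

import torch

img_shapes = [[(1, 2, 2), (1, 3, 2)]]   # target: 4 tokens; one condition image: 6 tokens
modulate_index = torch.tensor(
    [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
    dtype=torch.int,
)
print(modulate_index)
# tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)

Note that this builds one row per sample, so it implicitly assumes every sample in the batch has the same total token count; a ragged batch would make `torch.tensor` fail.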

@@ -641,6 +689,8 @@ def forward(
                     encoder_hidden_states_mask,
                     temb,
                     image_rotary_emb,
+                    attention_kwargs,
+                    modulate_index,
                 )
 
             else:
@@ -651,6 +701,7 @@ def forward(
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=attention_kwargs,
+                    modulate_index=modulate_index,
                 )
 
             # controlnet residual

656707
# controlnet residual
@@ -659,6 +710,8 @@ def forward(
659710
interval_control = int(np.ceil(interval_control))
660711
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
661712

713+
if self.zero_cond_t:
714+
temb = temb.chunk(2, dim=0)[0]
662715
# Use only the image part (hidden_states) from the dual-stream blocks
663716
hidden_states = self.norm_out(hidden_states, temb)
664717
output = self.proj_out(hidden_states)
@@ -670,4 +723,4 @@ def forward(
         if not return_dict:
            return (output,)
 
-        return Transformer2DModelOutput(sample=output)
\ No newline at end of file
+        return Transformer2DModelOutput(sample=output)
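Taken together, the zero-cond-t path runs every block with two timestep embeddings per sample: target tokens are modulated with the true timestep, while condition-image tokens are modulated with t = 0, presumably treating them as already-clean latents, and `norm_out` falls back to the true-timestep half. A sketch of the timestep doubling (values invented):

import torch

timestep = torch.tensor([0.7, 0.3])                    # true timesteps, B = 2
timestep = torch.cat([timestep, timestep * 0], dim=0)  # [2B]: [t; zeros]
print(timestep)  # tensor([0.7000, 0.3000, 0.0000, 0.0000])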
