diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..3adfcdb1478b 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -14,6 +14,7 @@
 
 import functools
 import math
+from math import prod
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -363,7 +364,13 @@ def __call__(
 @maybe_allow_in_graph
 class QwenImageTransformerBlock(nn.Module):
     def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
+        zero_cond_t: bool = False,
     ):
         super().__init__()
 
@@ -403,10 +410,43 @@ def __init__(
         self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
         self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-    def _modulate(self, x, mod_params):
+        self.zero_cond_t = zero_cond_t
+
+    def _modulate(self, x, mod_params, index=None):
         """Apply modulation to input tensor"""
+        # x: b l d, shift: b d, scale: b d, gate: b d
         shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+        if index is not None:
+            # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
+            # So shift, scale, gate have shape [2*actual_batch, d]
+            actual_batch = shift.size(0) // 2
+            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]  # each: [actual_batch, d]
+            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+            gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+
+            # index: [b, l] where b is actual batch size
+            # Expand to [b, l, 1] to match feature dimension
+            index_expanded = index.unsqueeze(-1)  # [b, l, 1]
+
+            # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
+            shift_0_exp = shift_0.unsqueeze(1)  # [b, 1, d]
+            shift_1_exp = shift_1.unsqueeze(1)  # [b, 1, d]
+            scale_0_exp = scale_0.unsqueeze(1)
+            scale_1_exp = scale_1.unsqueeze(1)
+            gate_0_exp = gate_0.unsqueeze(1)
+            gate_1_exp = gate_1.unsqueeze(1)
+
+            # Use torch.where to select based on index
+            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+        else:
+            shift_result = shift.unsqueeze(1)
+            scale_result = scale.unsqueeze(1)
+            gate_result = gate.unsqueeze(1)
+
+        return x * (1 + scale_result) + shift_result, gate_result
 
     def forward(
         self,
@@ -416,9 +456,13 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Get modulation parameters for both streams
         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+
+        if self.zero_cond_t:
+            temb = torch.chunk(temb, 2, dim=0)[0]
         txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
 
         # Split modulation parameters for norm1 and norm2
@@ -427,7 +471,7 @@ def forward(
 
         # Process image stream - norm1 + modulation
         img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
 
         # Process text stream - norm1 + modulation
         txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +501,7 @@ def forward(
 
         # Process image stream - norm2 + MLP
         img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
         img_mlp_output = self.img_mlp(img_modulated2)
         hidden_states = hidden_states + img_gate2 * img_mlp_output
 
@@ -533,6 +577,7 @@ def __init__(
         joint_attention_dim: int = 3584,
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        zero_cond_t: bool = False,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -553,6 +598,7 @@ def __init__(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
+                    zero_cond_t=zero_cond_t,
                 )
                 for _ in range(num_layers)
             ]
@@ -562,6 +608,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
+        self.zero_cond_t = zero_cond_t
 
     def forward(
         self,
@@ -618,6 +665,17 @@ def forward(
         hidden_states = self.img_in(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype)
+
+        if self.zero_cond_t:
+            timestep = torch.cat([timestep, timestep * 0], dim=0)
+            modulate_index = torch.tensor(
+                [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
+                device=timestep.device,
+                dtype=torch.int,
+            )
+        else:
+            modulate_index = None
+
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
@@ -641,6 +699,8 @@ def forward(
                     encoder_hidden_states_mask,
                     temb,
                     image_rotary_emb,
+                    attention_kwargs,
+                    modulate_index,
                 )
 
             else:
@@ -651,6 +711,7 @@ def forward(
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=attention_kwargs,
+                    modulate_index=modulate_index,
                 )
 
             # controlnet residual
@@ -659,6 +720,8 @@ def forward(
                 interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
+        if self.zero_cond_t:
+            temb = temb.chunk(2, dim=0)[0]
         # Use only the image part (hidden_states) from the dual-stream blocks
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
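For reviewers, here is a minimal standalone sketch (not part of the patch) of the per-token selection the new `_modulate(..., index=...)` path performs when `zero_cond_t` is enabled: `modulate_index` marks edited-image tokens with 0 and condition-image tokens with 1, and `torch.where` then picks per token between the real-timestep and zero-timestep modulation parameters. The `batch`, `dim`, and `img_shapes` values below are illustrative assumptions, not values from the PR.

```python
from math import prod

import torch

batch, dim = 2, 8
# Illustrative img_shapes: per sample, the edited image followed by one condition
# image, each given as (frames, height, width) in latent patches.
img_shapes = [[(1, 4, 4), (1, 4, 4)]] * batch

# 0 -> token of the edited image (real-timestep modulation),
# 1 -> token of a condition image (zeroed-timestep modulation).
modulate_index = torch.tensor(
    [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
    dtype=torch.int,
)  # [batch, seq_len]

# Stand-ins for the two modulation branches; in the block they come from
# img_mod(temb) with temb of batch size 2 * batch (real timesteps + zeroed timesteps).
shift_real = torch.randn(batch, dim)
shift_zero = torch.randn(batch, dim)

# Per-token selection as in _modulate: broadcast [batch, dim] -> [batch, 1, dim]
# and choose per token via torch.where over the [batch, seq_len, 1] index.
index_expanded = modulate_index.unsqueeze(-1)
shift = torch.where(index_expanded == 0, shift_real.unsqueeze(1), shift_zero.unsqueeze(1))
print(shift.shape)  # torch.Size([2, 32, 8])
```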