65 changes: 59 additions & 6 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -14,6 +14,7 @@

import functools
import math
from math import prod
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -363,7 +364,7 @@ def __call__(
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6, zero_cond_t: bool = False
):
super().__init__()

@@ -403,10 +404,43 @@ def __init__(
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

def _modulate(self, x, mod_params):
self.zero_cond_t = zero_cond_t

def _modulate(self, x, mod_params, index=None):
"""Apply modulation to input tensor"""
# x: b l d, shift: b d, scale: b d, gate: b d
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]

# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]

# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)

# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)

return x * (1 + scale_result) + shift_result, gate_result
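
To make the selection concrete, here is a minimal standalone sketch (shapes and values are illustrative, not taken from the PR) of how torch.where picks between the two stacked parameter sets per token:

import torch

b, l, d = 1, 4, 3
# two stacked shift sets, as produced by a doubled temb: [2*b, d]
shift = torch.stack([torch.full((d,), 0.5), torch.zeros(d)])
index = torch.tensor([[0, 0, 1, 1]])  # first two tokens use the real-t parameters
picked = torch.where(
    index.unsqueeze(-1) == 0,  # [b, l, 1], broadcasts against [b, 1, d]
    shift[:b].unsqueeze(1),    # real-t shift -> [b, 1, d]
    shift[b:].unsqueeze(1),    # zero-t shift -> [b, 1, d]
)
print(picked[0, :, 0])  # tensor([0.5000, 0.5000, 0.0000, 0.0000])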

def forward(
self,
@@ -416,9 +450,13 @@
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
modulate_index: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb)  # [B, 6*dim] ([2B, 6*dim] when zero_cond_t)

if self.zero_cond_t:
# the text stream always uses the real timestep, so keep only the first
# (non-zeroed) half of the doubled temb before computing txt_mod_params
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
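
For orientation, the shapes involved when zero_cond_t is True (a sketch; B is the actual batch size):

temb            # [2B, dim]: rows 0..B-1 embed the real t, rows B..2B-1 embed t = 0
img_mod_params  # [2B, 6*dim]: both parameter sets; _modulate picks per token via modulate_index
txt_mod_params  # [B, 6*dim]: text tokens are always modulated with the real t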

# Split modulation parameters for norm1 and norm2
@@ -427,7 +465,7 @@

# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)

# Process text stream - norm1 + modulation
txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +495,7 @@

# Process image stream - norm2 + MLP
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output

@@ -533,6 +571,7 @@ def __init__(
joint_attention_dim: int = 3584,
guidance_embeds: bool = False, # TODO: this should probably be removed
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
zero_cond_t: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -553,6 +592,7 @@
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
zero_cond_t=zero_cond_t,
)
for _ in range(num_layers)
]
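
A hedged usage sketch of the new flag: assuming the enclosing class is diffusers' QwenImageTransformer2DModel and the block list is stored as transformer_blocks (both as in the existing file; not shown in this diff), the flag threads down to every block:

from diffusers import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel(
    num_layers=2,
    num_attention_heads=4,
    attention_head_dim=32,
    zero_cond_t=True,  # new flag: conditioning tokens get zero-timestep modulation
)
assert all(block.zero_cond_t for block in model.transformer_blocks)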
@@ -562,6 +602,7 @@
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t

def forward(
self,
@@ -618,6 +659,13 @@ def forward(
hidden_states = self.img_in(hidden_states)

timestep = timestep.to(hidden_states.dtype)

if self.zero_cond_t:

Member
Totally optional. But it would be great to leave a comment, briefly explaining what this is doing.

# Double the timestep batch with a zeroed copy: the first half modulates the
# denoised image tokens with the real t, the second half modulates the
# conditioning image tokens with t = 0.
timestep = torch.cat([timestep, timestep * 0], dim=0)
# Per-token selector: 0 for tokens of the first (target) latent in each
# sample, 1 for every conditioning latent that follows it.
modulate_index = torch.tensor(
    [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
    device=timestep.device,
    dtype=torch.int,
)
else:
modulate_index = None
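
A small worked example of the mask construction (illustrative shapes): for one sample with a target latent of shape (1, 32, 32) and one conditioning latent of the same shape, the first 1024 token positions get index 0 and the remaining 1024 get index 1:

from math import prod

img_shapes = [[(1, 32, 32), (1, 32, 32)]]  # [target, conditioning] per sample
sample = img_shapes[0]
row = [0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:])
print(len(row), row[1023], row[1024])  # 2048 0 1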

encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)

@@ -641,6 +689,8 @@
encoder_hidden_states_mask,
temb,
image_rotary_emb,
attention_kwargs,
modulate_index,
)

else:
@@ -651,6 +701,7 @@
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
)

# controlnet residual
@@ -659,6 +710,8 @@
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

if self.zero_cond_t:
# keep only the real-timestep half so temb matches the batch size of
# hidden_states again before the final AdaLN
temb = temb.chunk(2, dim=0)[0]
# Use only the image part (hidden_states) from the dual-stream blocks
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
@@ -670,4 +723,4 @@
if not return_dict:
return (output,)

return Transformer2DModelOutput(sample=output)