 
 import functools
 import math
+from math import prod
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -363,7 +364,13 @@ def __call__(
 @maybe_allow_in_graph
 class QwenImageTransformerBlock(nn.Module):
     def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
+        zero_cond_t: bool = False,
     ):
         super().__init__()
 
@@ -403,10 +410,43 @@ def __init__(
         self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
         self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-    def _modulate(self, x, mod_params):
+        self.zero_cond_t = zero_cond_t
+
+    def _modulate(self, x, mod_params, index=None):
         """Apply modulation to input tensor"""
+        # x: b l d, shift: b d, scale: b d, gate: b d
         shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+        if index is not None:
+            # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
+            # So shift, scale, gate have shape [2*actual_batch, d]
+            actual_batch = shift.size(0) // 2
+            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]  # each: [actual_batch, d]
+            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+            gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+
+            # index: [b, l] where b is actual batch size
+            # Expand to [b, l, 1] to match feature dimension
+            index_expanded = index.unsqueeze(-1)  # [b, l, 1]
+
+            # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
+            shift_0_exp = shift_0.unsqueeze(1)  # [b, 1, d]
+            shift_1_exp = shift_1.unsqueeze(1)  # [b, 1, d]
+            scale_0_exp = scale_0.unsqueeze(1)
+            scale_1_exp = scale_1.unsqueeze(1)
+            gate_0_exp = gate_0.unsqueeze(1)
+            gate_1_exp = gate_1.unsqueeze(1)
+
+            # Use torch.where to select based on index
+            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+        else:
+            shift_result = shift.unsqueeze(1)
+            scale_result = scale.unsqueeze(1)
+            gate_result = gate.unsqueeze(1)
+
+        return x * (1 + scale_result) + shift_result, gate_result
 
     def forward(
         self,
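The `torch.where` selection added above picks, for every image token, either the modulation parameters from the first half of the doubled batch (index 0) or from the second half (index 1). A minimal standalone sketch of that mechanism, with toy shapes and names chosen purely for illustration:

```python
import torch

# Toy sizes, purely illustrative: batch of 2 samples, 5 tokens each, feature dim 4.
b, l, d = 2, 5, 4

# Modulation params from a doubled batch: first half for the real timestep,
# second half for the zero timestep.
shift = torch.randn(2 * b, d)
shift_0, shift_1 = shift[:b], shift[b:]

# Per-token selector: 0 -> first half, 1 -> second half.
index = torch.tensor([[0, 0, 0, 1, 1],
                      [0, 0, 1, 1, 1]])

selected = torch.where(
    index.unsqueeze(-1) == 0,   # [b, l, 1]
    shift_0.unsqueeze(1),       # [b, 1, d], broadcast over the token axis
    shift_1.unsqueeze(1),
)
print(selected.shape)  # torch.Size([2, 5, 4])
```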
@@ -416,9 +456,13 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        modulate_index: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Get modulation parameters for both streams
         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+
+        if self.zero_cond_t:
+            temb = torch.chunk(temb, 2, dim=0)[0]
         txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
 
         # Split modulation parameters for norm1 and norm2
@@ -427,7 +471,7 @@ def forward(
 
         # Process image stream - norm1 + modulation
         img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
 
         # Process text stream - norm1 + modulation
         txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +501,7 @@ def forward(
 
         # Process image stream - norm2 + MLP
         img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
         img_mlp_output = self.img_mlp(img_modulated2)
         hidden_states = hidden_states + img_gate2 * img_mlp_output
 
@@ -533,6 +577,7 @@ def __init__(
         joint_attention_dim: int = 3584,
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        zero_cond_t: bool = False,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -553,6 +598,7 @@ def __init__(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
+                    zero_cond_t=zero_cond_t,
                 )
                 for _ in range(num_layers)
             ]
@@ -562,6 +608,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
+        self.zero_cond_t = zero_cond_t
 
     def forward(
         self,
@@ -618,6 +665,17 @@ def forward(
         hidden_states = self.img_in(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype)
+
+        if self.zero_cond_t:
+            timestep = torch.cat([timestep, timestep * 0], dim=0)
+            modulate_index = torch.tensor(
+                [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
+                device=timestep.device,
+                dtype=torch.int,
+            )
+        else:
+            modulate_index = None
+
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
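For reference, the `modulate_index` built in this hunk marks the tokens of the first (denoised) image in each sample with 0 and the tokens of any additional conditioning images with 1. A small sketch of the same list comprehension with made-up `img_shapes` (the real shapes come from the pipeline):

```python
import torch
from math import prod

# Hypothetical shapes: per sample, the first tuple is the denoised latent,
# the remaining tuples are conditioning images (frames, height, width in patches).
img_shapes = [[(1, 4, 4), (1, 2, 2)]]  # 16 target tokens + 4 conditioning tokens

modulate_index = torch.tensor(
    [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in img_shapes],
    dtype=torch.int,
)
print(modulate_index.shape)        # torch.Size([1, 20])
print(modulate_index[0].tolist())  # 16 zeros followed by 4 ones
```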
@@ -641,6 +699,8 @@ def forward(
                     encoder_hidden_states_mask,
                     temb,
                     image_rotary_emb,
+                    attention_kwargs,
+                    modulate_index,
                 )
 
             else:
@@ -651,6 +711,7 @@ def forward(
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=attention_kwargs,
+                    modulate_index=modulate_index,
                 )
 
             # controlnet residual
@@ -659,6 +720,8 @@ def forward(
                 interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
+        if self.zero_cond_t:
+            temb = temb.chunk(2, dim=0)[0]
         # Use only the image part (hidden_states) from the dual-stream blocks
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
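Taken together, `zero_cond_t` doubles the timestep batch so that conditioning tokens are modulated as if at timestep 0, while the text stream and the final output norm keep using only the real-timestep half of `temb`. A minimal sketch of that bookkeeping, where the actual learned time embedder is stood in by a simple expansion for illustration:

```python
import torch

timestep = torch.tensor([0.7, 0.3])                     # real timesteps, batch of 2

# Append a zero timestep per sample, as in the zero_cond_t branch above.
timestep = torch.cat([timestep, timestep * 0], dim=0)   # [0.7, 0.3, 0.0, 0.0]

# Stand-in for the time embedding; the real model uses a learned embedder.
temb = timestep[:, None].expand(-1, 8)                  # [2B, dim]

# Image-stream modulation sees the full [2B, dim] temb; the text stream and
# the final norm_out only use the real-timestep half.
temb_real = temb.chunk(2, dim=0)[0]                     # [B, dim]
print(temb.shape, temb_real.shape)  # torch.Size([4, 8]) torch.Size([2, 8])
```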