1414
1515import functools
1616import math
17+ from math import prod
1718from typing import Any , Dict , List , Optional , Tuple , Union
1819
1920import numpy as np
@@ -363,7 +364,7 @@ def __call__(
363364@maybe_allow_in_graph
364365class QwenImageTransformerBlock (nn .Module ):
365366 def __init__ (
366- self , dim : int , num_attention_heads : int , attention_head_dim : int , qk_norm : str = "rms_norm" , eps : float = 1e-6
367+ self , dim : int , num_attention_heads : int , attention_head_dim : int , qk_norm : str = "rms_norm" , eps : float = 1e-6 , zero_cond_t : bool = False
367368 ):
368369 super ().__init__ ()
369370
@@ -403,10 +404,43 @@ def __init__(
403404 self .txt_norm2 = nn .LayerNorm (dim , elementwise_affine = False , eps = eps )
404405 self .txt_mlp = FeedForward (dim = dim , dim_out = dim , activation_fn = "gelu-approximate" )
405406
406- def _modulate (self , x , mod_params ):
407+ self .zero_cond_t = zero_cond_t
408+
409+ def _modulate (self , x , mod_params , index = None ):
407410 """Apply modulation to input tensor"""
411+ # x: b l d, shift: b d, scale: b d, gate: b d
408412 shift , scale , gate = mod_params .chunk (3 , dim = - 1 )
409- return x * (1 + scale .unsqueeze (1 )) + shift .unsqueeze (1 ), gate .unsqueeze (1 )
413+
414+ if index is not None :
415+ # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
416+ # So shift, scale, gate have shape [2*actual_batch, d]
417+ actual_batch = shift .size (0 ) // 2
418+ shift_0 , shift_1 = shift [:actual_batch ], shift [actual_batch :] # each: [actual_batch, d]
419+ scale_0 , scale_1 = scale [:actual_batch ], scale [actual_batch :]
420+ gate_0 , gate_1 = gate [:actual_batch ], gate [actual_batch :]
421+
422+ # index: [b, l] where b is actual batch size
423+ # Expand to [b, l, 1] to match feature dimension
424+ index_expanded = index .unsqueeze (- 1 ) # [b, l, 1]
425+
426+ # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
427+ shift_0_exp = shift_0 .unsqueeze (1 ) # [b, 1, d]
428+ shift_1_exp = shift_1 .unsqueeze (1 ) # [b, 1, d]
429+ scale_0_exp = scale_0 .unsqueeze (1 )
430+ scale_1_exp = scale_1 .unsqueeze (1 )
431+ gate_0_exp = gate_0 .unsqueeze (1 )
432+ gate_1_exp = gate_1 .unsqueeze (1 )
433+
434+ # Use torch.where to select based on index
435+ shift_result = torch .where (index_expanded == 0 , shift_0_exp , shift_1_exp )
436+ scale_result = torch .where (index_expanded == 0 , scale_0_exp , scale_1_exp )
437+ gate_result = torch .where (index_expanded == 0 , gate_0_exp , gate_1_exp )
438+ else :
439+ shift_result = shift .unsqueeze (1 )
440+ scale_result = scale .unsqueeze (1 )
441+ gate_result = gate .unsqueeze (1 )
442+
443+ return x * (1 + scale_result ) + shift_result , gate_result
410444
411445 def forward (
412446 self ,
@@ -416,9 +450,13 @@ def forward(
416450 temb : torch .Tensor ,
417451 image_rotary_emb : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
418452 joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
453+ modulate_index : Optional [List [int ]] = None ,
419454 ) -> Tuple [torch .Tensor , torch .Tensor ]:
420455 # Get modulation parameters for both streams
421456 img_mod_params = self .img_mod (temb ) # [B, 6*dim]
457+
458+ if self .zero_cond_t :
459+ temb = torch .chunk (temb , 2 , dim = 0 )[0 ]
422460 txt_mod_params = self .txt_mod (temb ) # [B, 6*dim]
423461
424462 # Split modulation parameters for norm1 and norm2
@@ -427,7 +465,7 @@ def forward(
427465
428466 # Process image stream - norm1 + modulation
429467 img_normed = self .img_norm1 (hidden_states )
430- img_modulated , img_gate1 = self ._modulate (img_normed , img_mod1 )
468+ img_modulated , img_gate1 = self ._modulate (img_normed , img_mod1 , modulate_index )
431469
432470 # Process text stream - norm1 + modulation
433471 txt_normed = self .txt_norm1 (encoder_hidden_states )
@@ -457,7 +495,7 @@ def forward(
457495
458496 # Process image stream - norm2 + MLP
459497 img_normed2 = self .img_norm2 (hidden_states )
460- img_modulated2 , img_gate2 = self ._modulate (img_normed2 , img_mod2 )
498+ img_modulated2 , img_gate2 = self ._modulate (img_normed2 , img_mod2 , modulate_index )
461499 img_mlp_output = self .img_mlp (img_modulated2 )
462500 hidden_states = hidden_states + img_gate2 * img_mlp_output
463501
@@ -533,6 +571,7 @@ def __init__(
533571 joint_attention_dim : int = 3584 ,
534572 guidance_embeds : bool = False , # TODO: this should probably be removed
535573 axes_dims_rope : Tuple [int , int , int ] = (16 , 56 , 56 ),
574+ zero_cond_t : bool = False
536575 ):
537576 super ().__init__ ()
538577 self .out_channels = out_channels or in_channels
@@ -553,6 +592,7 @@ def __init__(
553592 dim = self .inner_dim ,
554593 num_attention_heads = num_attention_heads ,
555594 attention_head_dim = attention_head_dim ,
595+ zero_cond_t = zero_cond_t
556596 )
557597 for _ in range (num_layers )
558598 ]
@@ -562,6 +602,7 @@ def __init__(
562602 self .proj_out = nn .Linear (self .inner_dim , patch_size * patch_size * self .out_channels , bias = True )
563603
564604 self .gradient_checkpointing = False
605+ self .zero_cond_t = zero_cond_t
565606
566607 def forward (
567608 self ,
@@ -618,6 +659,13 @@ def forward(
618659 hidden_states = self .img_in (hidden_states )
619660
620661 timestep = timestep .to (hidden_states .dtype )
662+
663+ if self .zero_cond_t :
664+ timestep = torch .cat ([timestep , timestep * 0 ], dim = 0 )
665+ modulate_index = torch .tensor ([[0 ]* prod (sample [0 ]) + [1 ]* sum ([prod (s ) for s in sample [1 :]]) for sample in img_shapes ], device = timestep .device , dtype = torch .int )
666+ else :
667+ modulate_index = None
668+
621669 encoder_hidden_states = self .txt_norm (encoder_hidden_states )
622670 encoder_hidden_states = self .txt_in (encoder_hidden_states )
623671
@@ -641,6 +689,8 @@ def forward(
641689 encoder_hidden_states_mask ,
642690 temb ,
643691 image_rotary_emb ,
692+ attention_kwargs ,
693+ modulate_index ,
644694 )
645695
646696 else :
@@ -651,6 +701,7 @@ def forward(
651701 temb = temb ,
652702 image_rotary_emb = image_rotary_emb ,
653703 joint_attention_kwargs = attention_kwargs ,
704+ modulate_index = modulate_index ,
654705 )
655706
656707 # controlnet residual
@@ -659,6 +710,8 @@ def forward(
659710 interval_control = int (np .ceil (interval_control ))
660711 hidden_states = hidden_states + controlnet_block_samples [index_block // interval_control ]
661712
713+ if self .zero_cond_t :
714+ temb = temb .chunk (2 , dim = 0 )[0 ]
662715 # Use only the image part (hidden_states) from the dual-stream blocks
663716 hidden_states = self .norm_out (hidden_states , temb )
664717 output = self .proj_out (hidden_states )
@@ -670,4 +723,4 @@ def forward(
670723 if not return_dict :
671724 return (output ,)
672725
673- return Transformer2DModelOutput (sample = output )
726+ return Transformer2DModelOutput (sample = output )
0 commit comments