import torch.nn as nn
from typing import Any, Dict, List, Tuple, Union, Optional
from einops import rearrange
+from math import prod

from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
from diffsynth_engine.models.basic import attention as attention_ops
@@ -243,6 +244,7 @@ def __init__(
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
+        zero_cond_t: bool = False,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
@@ -275,10 +277,30 @@ def __init__(
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
        self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim, device=device, dtype=dtype)
+        self.zero_cond_t = zero_cond_t

-    def _modulate(self, x, mod_params):
+    def _modulate(self, x, mod_params, index=None):
        shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+        if index is not None:
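+            # With zero_cond_t, mod_params carries a doubled batch: the first half is
+            # computed at the real timestep, the second half at t=0. `index` selects
+            # which half applies to each token (0 -> real t, 1 -> zero t).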
+            actual_batch = shift.size(0) // 2
+            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]
+            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+            gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+            index_expanded = index.unsqueeze(-1)
+            shift_0_exp = shift_0.unsqueeze(1)
+            shift_1_exp = shift_1.unsqueeze(1)
+            scale_0_exp = scale_0.unsqueeze(1)
+            scale_1_exp = scale_1.unsqueeze(1)
+            gate_0_exp = gate_0.unsqueeze(1)
+            gate_1_exp = gate_1.unsqueeze(1)
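+            # broadcast over the sequence dim and select per-token params by index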
+            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+        else:
+            shift_result = shift.unsqueeze(1)
+            scale_result = scale.unsqueeze(1)
+            gate_result = gate.unsqueeze(1)
+        return x * (1 + scale_result) + shift_result, gate_result

    def forward(
        self,
@@ -288,12 +310,15 @@ def forward(
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_mask: Optional[torch.Tensor] = None,
        attn_kwargs: Optional[Dict[str, Any]] = None,
+        modulate_index: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+        if self.zero_cond_t:
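+            # text tokens are always modulated at the real timestep, so keep only
+            # the first (real-t) half of the doubled temb for txt_mod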
+            temb = torch.chunk(temb, 2, dim=0)[0]
        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each

        img_normed = self.img_norm1(image)
-        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, modulate_index)

        txt_normed = self.txt_norm1(text)
        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -305,12 +330,11 @@ def forward(
            attn_mask=attn_mask,
            attn_kwargs=attn_kwargs,
        )
-
        image = image + img_gate * img_attn_out
        text = text + txt_gate * txt_attn_out

        img_normed_2 = self.img_norm2(image)
-        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, modulate_index)

        txt_normed_2 = self.txt_norm2(text)
        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
@@ -331,6 +355,7 @@ class QwenImageDiT(PreTrainedModel):
    def __init__(
        self,
        num_layers: int = 60,
+        zero_cond_t: bool = False,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
@@ -351,6 +376,7 @@ def __init__(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
+                    zero_cond_t=zero_cond_t,
                    device=device,
                    dtype=dtype,
                )
@@ -359,6 +385,7 @@ def __init__(
        )
        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+        self.zero_cond_t = zero_cond_t

    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
@@ -461,6 +488,9 @@ def forward(
                use_cfg=use_cfg,
            ),
        ):
+            if self.zero_cond_t:
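+                # append a zeroed timestep so condition-image tokens can be modulated
+                # at t=0; which tokens use it is decided later via modulate_index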
+                timestep = torch.cat([timestep, timestep * 0], dim=0)
+            modulate_index = None
            conditioning = self.time_text_embed(timestep, image.dtype)
            video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
            text_seq_len = text_seq_lens.max().item()
@@ -478,7 +508,12 @@ def forward(
                img = self.patchify(img)
                image = torch.cat([image, img], dim=1)
                video_fhw += [(1, edit_h // 2, edit_w // 2)]
-
+            if self.zero_cond_t:
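+                # 0 marks the noisy latent tokens (first entry of video_fhw, real t),
+                # 1 marks the appended edit-image tokens (modulated at t=0)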
+                modulate_index = torch.tensor(
+                    [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [video_fhw]],
+                    device=timestep.device,
+                    dtype=torch.int,
+                )
            rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)

            image = self.img_in(image)
@@ -510,7 +545,10 @@ def forward(
                    rotary_emb=rotary_emb,
                    attn_mask=attn_mask,
                    attn_kwargs=attn_kwargs,
+                    modulate_index=modulate_index,
                )
+            if self.zero_cond_t:
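+                # the final AdaLayerNorm uses only the real-timestep half of the conditioning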
+                conditioning = conditioning.chunk(2, dim=0)[0]
            image = self.norm_out(image, conditioning)
            image = self.proj_out(image)
            (image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,))
@@ -527,8 +565,9 @@ def from_state_dict(
        device: str,
        dtype: torch.dtype,
        num_layers: int = 60,
+        use_zero_cond_t: bool = False,
    ):
-        model = cls(device="meta", dtype=dtype, num_layers=num_layers)
+        model = cls(device="meta", dtype=dtype, num_layers=num_layers, zero_cond_t=use_zero_cond_t)
        model = model.requires_grad_(False)
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
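A minimal usage sketch of the new flag (the `state_dict` variable and the positional parameter preceding `device` are assumed from context, not shown in this diff):

    # hypothetical call: build the DiT with zero-timestep conditioning for edit tokens
    model = QwenImageDiT.from_state_dict(
        state_dict,
        device="cuda:0",
        dtype=torch.bfloat16,
        num_layers=60,
        use_zero_cond_t=True,
    )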