@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
+    def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
         super().__init__()
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(
@@ -72,9 +72,19 @@ def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None
             operations=operations
         )
 
-    def forward(self, timestep, hidden_states):
+        self.use_additional_t_cond = use_additional_t_cond
+        if self.use_additional_t_cond:
+            self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
+
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
+
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
+            timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
+
         return timesteps_emb
 
 
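Note on this hunk: below is a minimal standalone sketch of the new conditioning path, assuming plain `torch.nn` layers in place of comfy's `operations` wrapper (so no `out_dtype` kwarg) and a toy MLP standing in for `TimestepEmbedding`; `ToyTimestepEmbed` is not a name from this PR. The binary condition index is looked up in a 2-entry embedding table and added onto the timestep embedding, and a missing condition defaults to index 0 so the output matches a model that never sees the flag.

```python
import torch
import torch.nn as nn

class ToyTimestepEmbed(nn.Module):
    def __init__(self, embedding_dim, use_additional_t_cond=False):
        super().__init__()
        # toy stand-in for Timesteps + TimestepEmbedding
        self.mlp = nn.Sequential(nn.Linear(256, embedding_dim), nn.SiLU(), nn.Linear(embedding_dim, embedding_dim))
        self.use_additional_t_cond = use_additional_t_cond
        if use_additional_t_cond:
            # two learned vectors: index 0 = "off" (the default), index 1 = "on"
            self.addition_t_embedding = nn.Embedding(2, embedding_dim)

    def forward(self, timesteps_proj, addition_t_cond=None):
        temb = self.mlp(timesteps_proj)
        if self.use_additional_t_cond:
            if addition_t_cond is None:
                # no condition supplied -> index 0, i.e. the "off" embedding
                addition_t_cond = torch.zeros(temb.shape[0], dtype=torch.long)
            temb = temb + self.addition_t_embedding(addition_t_cond)
        return temb

embed = ToyTimestepEmbed(128, use_additional_t_cond=True)
out = embed(torch.randn(2, 256), addition_t_cond=torch.tensor([0, 1]))
print(out.shape)  # torch.Size([2, 128])
```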
@@ -320,11 +330,11 @@ def __init__(
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
         pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         default_ref_method="index",
         image_model=None,
         final_layer=True,
+        use_additional_t_cond=False,
         dtype=None,
         device=None,
         operations=None,
@@ -342,6 +352,7 @@ def __init__(
         self.time_text_embed = QwenTimestepProjEmbeddings(
             embedding_dim=self.inner_dim,
             pooled_projection_dim=pooled_projection_dim,
+            use_additional_t_cond=use_additional_t_cond,
             dtype=dtype,
             device=device,
             operations=operations
@@ -375,36 +386,42 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
         patch_size = self.patch_size
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
         orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        t_len = t
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
         h_offset = ((h_offset + (patch_size // 2)) // patch_size)
         w_offset = ((w_offset + (patch_size // 2)) // patch_size)
 
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
-        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
-        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
+
+        if t_len > 1:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
+        else:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
+
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
+        return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
 
-    def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+        ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
 
     def _forward(
         self,
         x,
         timesteps,
         context,
         attention_mask=None,
-        guidance: torch.Tensor = None,
         ref_latents=None,
+        additional_t_cond=None,
         transformer_options={},
         control=None,
         **kwargs
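For reference, a standalone shape check of what the new 3D patchify and the `(t, y, x)` position ids in this hunk amount to: each frame is split into 2x2 patches, the `(T, H/2, W/2)` grid is flattened into a single token axis with channel dimension `C*4`, and each token gets a frame index plus centered row/column coordinates. `patchify_3d` is an illustrative name, not from this PR, and H/W are assumed to already be padded to multiples of the patch size.

```python
import torch

def patchify_3d(latent):
    b, c, t, h, w = latent.shape
    x = latent.view(b, c, t, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 3, 5, 1, 4, 6)                  # (b, t, h/2, w/2, c, 2, 2)
    return x.reshape(b, t * (h // 2) * (w // 2), c * 4)  # tokens over (t, h/2, w/2)

latent = torch.randn(1, 16, 3, 8, 8)        # (B, C, T, H, W)
tokens = patchify_3d(latent)
print(tokens.shape)                          # torch.Size([1, 48, 64]) = (1, 3*4*4, 16*4)

# per-token rope ids: frame index on axis 0, centered row/col on axes 1/2
t_len, h_len, w_len = 3, 4, 4
ids = torch.zeros((t_len, h_len, w_len, 3))
ids[:, :, :, 0] = torch.arange(t_len).view(-1, 1, 1)
ids[:, :, :, 1] = torch.arange(h_len).view(1, -1, 1) - h_len // 2
ids[:, :, :, 2] = torch.arange(w_len).view(1, 1, -1) - w_len // 2
print(ids.reshape(-1, 3).shape)              # torch.Size([48, 3]) matches the token count
```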
@@ -423,12 +440,17 @@ def _forward(
             index = 0
             ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
             index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            negative_ref_method = ref_method == "negative_index"
             timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
                     h_offset = 0
                     w_offset = 0
+                elif negative_ref_method:
+                    index -= 1
+                    h_offset = 0
+                    w_offset = 0
                 else:
                     index = 1
                     h_offset = 0
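Illustrative sketch (not part of the PR) of how the per-reference index coordinate comes out of this loop: `index`/`index_timestep_zero` count up, the new `negative_index` counts down, and the offset-based fallback pins every reference at index 1.

```python
def ref_indices(num_refs, ref_method):
    indices = []
    index = 0
    for _ in range(num_refs):
        if ref_method in ("index", "index_timestep_zero"):
            index += 1           # refs sit at +1, +2, ... along the index axis
        elif ref_method == "negative_index":
            index -= 1           # refs sit at -1, -2, ... instead
        else:
            index = 1            # offset-based methods keep every ref at index 1
        indices.append(index)
    return indices

print(ref_indices(3, "index"))           # [1, 2, 3]
print(ref_indices(3, "negative_index"))  # [-1, -2, -3]
print(ref_indices(3, "offset"))          # [1, 1, 1]
```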
@@ -458,14 +480,7 @@ def _forward(
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
-        if guidance is not None:
-            guidance = guidance * 1000
-
-        temb = (
-            self.time_text_embed(timestep, hidden_states)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
-        )
+        temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
 
         patches_replace = transformer_options.get("patches_replace", {})
         patches = transformer_options.get("patches", {})
@@ -513,6 +528,6 @@ def block_wrap(args):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
-        hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
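Sanity check (not part of the PR): the updated output view/permute appears to be the exact inverse of the 3D patchify at the top of `process_img`, so a patchify followed by this unpatchify recovers the original latent. A minimal round-trip sketch:

```python
import torch

def patchify_3d(x):
    b, c, t, h, w = x.shape
    x = x.view(b, c, t, h // 2, 2, w // 2, 2).permute(0, 2, 3, 5, 1, 4, 6)
    return x.reshape(b, t * (h // 2) * (w // 2), c * 4)

def unpatchify_3d(tokens, orig_shape):
    b, c, t, h, w = orig_shape
    x = tokens.view(b, t, h // 2, w // 2, c, 2, 2)
    x = x.permute(0, 4, 1, 2, 5, 3, 6)          # (b, c, t, h/2, 2, w/2, 2)
    return x.reshape(orig_shape)

latent = torch.randn(2, 16, 4, 8, 8)
assert torch.equal(unpatchify_3d(patchify_3d(latent), latent.shape), latent)
```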