 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
-    CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
     PixArtAlphaTextProjection,
     TimestepEmbedding,
@@ -179,6 +178,7 @@ def forward(
 class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
     def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
         super().__init__()
+
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
 
@@ -228,8 +228,53 @@ def forward(
         )
 
 
+class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
+
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return hidden_states, gate_msa, tr_gate_msa
+
+
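A quick shape check for the new norm (reviewer sketch, not part of the diff; all sizes below are made-up assumptions):

    import torch

    norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(embedding_dim=64)
    x = torch.randn(2, 16, 64)   # (batch, tokens, dim)
    emb = torch.randn(2, 64)     # regular conditioning embedding
    tr_emb = torch.randn(2, 64)  # first-frame ("token replace") embedding
    out, gate, tr_gate = norm(x, emb, tr_emb, first_frame_num_tokens=4)
    # the first 4 tokens get tr_emb's shift/scale, the remaining 12 get emb's
    assert out.shape == (2, 16, 64)
    assert gate.shape == tr_gate.shape == (2, 64)
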
 class HunyuanVideoConditionEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, pooled_projection_dim: int, guidance_embeds: bool, image_condition_type: Optional[str] = None):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        guidance_embeds: bool,
+        image_condition_type: Optional[str] = None,
+    ):
         super().__init__()
 
         self.image_condition_type = image_condition_type
@@ -242,7 +287,9 @@ def __init__(self, embedding_dim: int, pooled_projection_dim: int, guidance_embe
         if guidance_embeds:
             self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
 
-    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
         pooled_projections = self.text_embedder(pooled_projection)
@@ -480,6 +527,8 @@ def forward(
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -558,6 +607,7 @@ def forward(
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        *args,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
@@ -594,6 +644,86 @@ def forward(
         return hidden_states, encoder_hidden_states
 
 
+class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+        mlp_dim = int(hidden_size * mlp_ratio)
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+            pre_only=True,
+        )
+
+        self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: Optional[torch.Tensor] = None,
+        num_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        text_seq_length = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        residual = hidden_states
+
+        # 1. Input normalization
+        norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
+        # 2. Attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+        # 3. Modulation and residual connection
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+
+        proj_output = self.proj_out(hidden_states)
+        hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
+        hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        hidden_states = hidden_states + residual
+
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
+        return hidden_states, encoder_hidden_states
+
+
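A minimal smoke test of the new single-stream block (sketch only; assumes the classes in this file are importable and PyTorch 2.x SDPA is available, sizes are arbitrary):

    import torch

    block = HunyuanVideoTokenReplaceSingleTransformerBlock(num_attention_heads=4, attention_head_dim=32)
    video = torch.randn(1, 20, 128)  # 20 video tokens, hidden size 4 * 32
    text = torch.randn(1, 7, 128)    # 7 text tokens, concatenated inside the block
    temb = torch.randn(1, 128)       # pooled timestep/text conditioning
    tr_emb = torch.randn(1, 128)     # first-frame conditioning
    video, text = block(video, text, temb, token_replace_emb=tr_emb, num_tokens=5)
    assert video.shape == (1, 20, 128) and text.shape == (1, 7, 128)
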
 class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -664,8 +794,8 @@ def forward(
         )
 
         # 3. Modulation and residual connection
-        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa
-        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
         hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
         encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
 
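The `.unsqueeze(1)` added above fixes a broadcasting bug: the gates come out of `chunk(..., dim=1)` with shape (batch, dim), while `attn_output` is (batch, tokens, dim). A standalone illustration with hypothetical shapes:

    import torch

    gate = torch.randn(2, 64)           # (batch, dim), as produced by chunk(..., dim=1)
    attn_out = torch.randn(2, 10, 64)   # (batch, tokens, dim)
    # attn_out * gate would raise: (2, 64) does not broadcast against (2, 10, 64)
    out = attn_out * gate.unsqueeze(1)  # (2, 1, 64) broadcasts over the token axis
    assert out.shape == (2, 10, 64)
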
@@ -681,8 +811,8 @@ def forward(
         ff_output = self.ff(norm_hidden_states)
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
 
-        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp
-        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
         hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
 
@@ -802,14 +932,24 @@ def __init__(
             )
 
         # 4. Single stream transformer blocks
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoSingleTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_single_layers)
-            ]
-        )
+        if image_condition_type == "token_replace":
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
+        else:
+            self.single_transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoSingleTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_single_layers)
+                ]
+            )
 
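With this change, `image_condition_type="token_replace"` selects the token-replace variant for the single-stream stack as well. A hypothetical smoke test (constructor kwargs are assumptions; the real config carries more options):

    model = HunyuanVideoTransformer3DModel(
        num_attention_heads=4,
        attention_head_dim=32,
        num_layers=1,
        num_single_layers=1,
        image_condition_type="token_replace",
    )
    assert isinstance(model.single_transformer_blocks[0], HunyuanVideoTokenReplaceSingleTransformerBlock)
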
         # 5. Output projection
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -957,6 +1097,8 @@ def forward(
                     temb,
                     attention_mask,
                     image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         else:
@@ -973,7 +1115,13 @@ def forward(
 
             for block in self.single_transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
         # 5. Output projection