 from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    Timesteps,
     get_1d_rotary_pos_embed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -173,6 +176,84 @@ def forward(
         return gate_msa, gate_mlp
 
 
+class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+        token_replace_emb: torch.Tensor,
+        first_frame_num_tokens: int,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+        tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
+            6, dim=1
+        )
+
+        norm_hidden_states = self.norm(hidden_states)
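+        # Tokens of the conditioning (first) frame are modulated with the token-replace
+        # shift/scale; all remaining tokens use the regular timestep modulation.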
+        hidden_states_zero = (
+            norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+        )
+        hidden_states_orig = (
+            norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        )
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        )
+
+
+class HunyuanVideoTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, pooled_projection_dim: int, image_condition_type: Optional[str] = None):
+        super().__init__()
+
+        self.image_condition_type = image_condition_type
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = timesteps_emb + pooled_projections
+
+        token_replace_emb = None
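+        # The conditioning-image tokens are treated as already clean, so their
+        # modulation embedding is built from timestep 0 rather than the sampled timestep.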
+        if self.image_condition_type == "token_replace":
+            token_replace_timestep = torch.zeros_like(timestep)
+            token_replace_proj = self.time_proj(token_replace_timestep)
+            token_replace_emb = self.timestep_embedder(token_replace_proj)
+            token_replace_emb = token_replace_emb + conditioning
+
+        return conditioning, token_replace_emb
+
+
 class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
@@ -468,6 +549,7 @@ def forward(
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
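+        # *args/**kwargs absorb the token-replace arguments (token_replace_emb,
+        # first_frame_num_tokens) that the model passes uniformly to every block.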
+        *args,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -503,6 +585,101 @@ def forward(
         return hidden_states, encoder_hidden_states
 
 
+class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        token_replace_emb: Optional[torch.Tensor] = None,
+        num_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            tr_gate_msa,
+            tr_shift_mlp,
+            tr_scale_mlp,
+            tr_gate_mlp,
+        ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
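+        # The first `num_tokens` tokens (the conditioning frame) use the token-replace
+        # gates; gates are (batch, dim), so unsqueeze(1) lets them broadcast over tokens.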
+        hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
+        hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
+        hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
+        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -570,8 +747,10 @@ def __init__(
         pooled_projection_dim: int = 768,
         rope_theta: float = 256.0,
         rope_axes_dim: Tuple[int] = (16, 56, 56),
+        image_condition_type: Optional[str] = None,
     ) -> None:
         super().__init__()
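+        # "token_replace" re-modulates the first-frame tokens with a clean (timestep-0)
+        # embedding; "latent_concat" instead concatenates image latents with the video latents.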
+        assert image_condition_type is None or image_condition_type in ["latent_concat", "token_replace"]
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
@@ -585,20 +764,32 @@ def __init__(
         if guidance_embeds:
             self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
         else:
-            self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
+            self.time_text_embed = HunyuanVideoTimestepTextProjEmbeddings(
+                inner_dim, pooled_projection_dim, image_condition_type
+            )
 
         # 2. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
 
         # 3. Dual stream transformer blocks
-        self.transformer_blocks = nn.ModuleList(
-            [
-                HunyuanVideoTransformerBlock(
-                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
-                )
-                for _ in range(num_layers)
-            ]
-        )
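+        # Token-replace conditioning modulates the first-frame tokens separately, so it
+        # needs a dedicated dual-stream block type.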
+        if image_condition_type == "token_replace":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTokenReplaceTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    HunyuanVideoTransformerBlock(
+                        num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
 
         # 4. Single stream transformer blocks
         self.single_transformer_blocks = nn.ModuleList(
@@ -707,6 +898,7 @@ def forward(
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p
         post_patch_width = width // p
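+        # Token count of the first latent frame (the image-conditioning frame) after
+        # patchification; these tokens receive the token-replace treatment.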
+        first_frame_num_tokens = 1 * post_patch_height * post_patch_width
 
         # 1. RoPE
         image_rotary_emb = self.rope(hidden_states)
@@ -715,7 +907,7 @@ def forward(
         if self.config.guidance_embeds:
             temb = self.time_text_embed(timestep, guidance, pooled_projections)
+            token_replace_emb = None  # guidance models don't use token-replace conditioning
         else:
-            temb = self.time_text_embed(timestep, pooled_projections)
+            temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
 
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
@@ -746,6 +938,8 @@ def forward(
                     temb,
                     attention_mask,
                     image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks:
@@ -761,7 +955,13 @@ def forward(
         else:
             for block in self.transformer_blocks:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    token_replace_emb,
+                    first_frame_num_tokens,
                 )
 
             for block in self.single_transformer_blocks: