@@ -360,13 +360,13 @@ class IndividualTokenRefinerBlock(nn.Module):
     def __init__(
         self,
         hidden_size,
-        heads_num,
+        num_attention_heads: int,
         mlp_width_ratio: str = 4.0,
         mlp_drop_rate: float = 0.0,
         qkv_bias: bool = True,
-    ):
+    ) -> None:
         super().__init__()
-        self.heads_num = heads_num
+        self.heads_num = num_attention_heads

         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
         self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
@@ -383,25 +383,25 @@ def __init__(

     def forward(
         self,
-        x: torch.Tensor,
-        c: torch.Tensor,
-        attn_mask: torch.Tensor = None,
-    ):
-        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1)

-        norm_x = self.norm1(x)
+        norm_x = self.norm1(hidden_states)
         qkv = self.self_attn_qkv(norm_x)
         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

         # Self-Attention
-        attn = attention(q, k, v, attn_mask=attn_mask)
+        attn = attention(q, k, v, attn_mask=attention_mask)

-        x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
+        hidden_states = hidden_states + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)

         # FFN Layer
-        x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1)

-        return x
+        return hidden_states


 class IndividualTokenRefiner(nn.Module):
@@ -419,7 +419,7 @@ def __init__(
             [
                 IndividualTokenRefinerBlock(
                     hidden_size=hidden_size,
-                    heads_num=heads_num,
+                    num_attention_heads=heads_num,
                     mlp_width_ratio=mlp_width_ratio,
                     mlp_drop_rate=mlp_drop_rate,
                     qkv_bias=qkv_bias,
@@ -430,41 +430,34 @@ def __init__(

     def forward(
         self,
-        x: torch.Tensor,
-        c: torch.LongTensor,
-        mask: Optional[torch.Tensor] = None,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
     ):
         self_attn_mask = None
-        if mask is not None:
-            batch_size = mask.shape[0]
-            seq_len = mask.shape[1]
-            mask = mask.to(x.device).bool()
-            # batch_size x 1 x seq_len x seq_len
-            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
-            # batch_size x 1 x seq_len x seq_len
+        if attention_mask is not None:
+            batch_size = attention_mask.shape[0]
+            seq_len = attention_mask.shape[1]
+            attention_mask = attention_mask.to(hidden_states.device).bool()
+            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
             self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
-            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
             self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
-            # avoids self-attention weight being NaN for padding tokens
             self_attn_mask[:, :, :, 0] = True

         for block in self.blocks:
-            x = block(x, c, self_attn_mask)
-        return x
+            hidden_states = block(hidden_states, temb, self_attn_mask)
+
+        return hidden_states


 class SingleTokenRefiner(nn.Module):
-    """
-    A single token refiner block for llm text embedding refine.
-    """
-
     def __init__(
         self,
-        in_channels,
-        hidden_size,
-        num_attention_heads,
-        depth,
-        mlp_width_ratio: float = 4.0,
+        in_channels: int,
+        hidden_size: int,
+        num_attention_heads: int,
+        depth: int,
+        mlp_ratio: float = 4.0,
         mlp_drop_rate: float = 0.0,
         qkv_bias: bool = True,
     ):
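For readers following the mask handling in the hunk above, here is a minimal, self-contained sketch (shapes and values invented for illustration; the module itself is not involved) of how the refiner's forward pass turns a 1D padding mask into the pairwise self-attention mask handed to each block:

import torch

# Illustrative sizes only; in the model these come from the text encoder output.
batch_size, seq_len = 2, 6

# 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]]).bool()

# Expand to (batch_size, 1, seq_len, seq_len): position (i, j) is kept
# only if both token i and token j are real tokens.
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()

# Keep the first key column visible so rows belonging to padding tokens are
# never fully masked, which would otherwise produce NaN attention weights.
self_attn_mask[:, :, :, 0] = True

print(self_attn_mask.shape)  # torch.Size([2, 1, 6, 6])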
@@ -481,7 +474,7 @@ def __init__(
             hidden_size=hidden_size,
             heads_num=num_attention_heads,
             depth=depth,
-            mlp_width_ratio=mlp_width_ratio,
+            mlp_width_ratio=mlp_ratio,
             mlp_drop_rate=mlp_drop_rate,
             qkv_bias=qkv_bias,
         )
@@ -587,28 +580,31 @@ def forward(
 class HunyuanVideoTransformerBlock(nn.Module):
     def __init__(
         self,
-        hidden_size: int,
-        heads_num: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
         mlp_ratio: float,
         qk_norm: str = "rms_norm",
-    ):
+    ) -> None:
         super().__init__()

-        self.heads_num = heads_num
-        head_dim = hidden_size // heads_num
+        hidden_size = num_attention_heads * attention_head_dim

         self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
         self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")

-        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3)
-        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.img_attn_proj = nn.Linear(hidden_size, hidden_size)
-
-        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3)
-        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
-        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size)
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )

         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
@@ -627,35 +623,15 @@ def forward(
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
-        img_qkv = self.img_attn_qkv(norm_hidden_states)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        # Apply QK-Norm if needed
-        img_q = self.img_attn_q_norm(img_q).to(img_v)
-        img_k = self.img_attn_k_norm(img_k).to(img_v)
-
-        # Apply RoPE if needed.
-        if freqs_cis is not None:
-            img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
-            assert (
-                img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
-            ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
-            img_q, img_k = img_qq, img_kk
-
-        txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
-        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
-
-        q = torch.cat((img_q, txt_q), dim=1)
-        k = torch.cat((img_k, txt_k), dim=1)
-        v = torch.cat((img_v, txt_v), dim=1)
-        attn = attention(q, k, v)
-
-        img_attn, txt_attn = attn[:, : hidden_states.shape[1]], attn[:, hidden_states.shape[1] :]
-
-        hidden_states = hidden_states + self.img_attn_proj(img_attn) * gate_msa.unsqueeze(1)
-        encoder_hidden_states = encoder_hidden_states + self.txt_attn_proj(txt_attn) * c_gate_msa.unsqueeze(1)
+
+        img_attn, txt_attn = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=freqs_cis,
+        )
+
+        hidden_states = hidden_states + img_attn * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + txt_attn * c_gate_msa.unsqueeze(1)

         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
@@ -686,15 +662,14 @@ def __init__(
         patch_size_t: int = 1,
         rope_dim_list: List[int] = [16, 56, 56],
         qk_norm: str = "rms_norm",
-        guidance_embed: bool = True,
+        guidance_embeds: bool = True,
         text_embed_dim: int = 4096,
         text_embed_dim_2: int = 768,
     ) -> None:
         super().__init__()

         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
-        self.guidance_embed = guidance_embed
         self.rope_dim_list = rope_dim_list

         # image projection
@@ -714,7 +689,7 @@ def __init__(

         self.transformer_blocks = nn.ModuleList(
             [
-                HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio, qk_norm=qk_norm)
+                HunyuanVideoTransformerBlock(num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm)
                 for _ in range(num_layers)
             ]
         )
@@ -816,18 +791,9 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p

-        # Prepare modulation vectors.
         temb = self.time_in(timestep)
-
-        # text modulation
         temb = temb + self.vector_in(encoder_hidden_states_2)
-
-        # guidance modulation
-        if self.guidance_embed:
-            if guidance is None:
-                raise ValueError("Didn't get guidance strength for guidance distilled model.")
-
-            temb = temb + self.guidance_in(guidance)
+        temb = temb + self.guidance_in(guidance)

         # Embed image and text.
         hidden_states = self.img_in(hidden_states)
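As a usage note on the signature change above: the transformer block is now parameterized by head count and head dimension instead of a flat hidden size, mirroring how the top-level model builds its blocks. A rough sketch of old versus new instantiation, with illustrative values (the class is assumed to be importable from this module):

# Illustrative values; the hidden size is derived as heads * head_dim.
num_attention_heads, attention_head_dim, mlp_ratio = 24, 128, 4.0
inner_dim = num_attention_heads * attention_head_dim  # 3072

# Before this change: the block received the flat hidden size plus the head count.
# block = HunyuanVideoTransformerBlock(inner_dim, num_attention_heads, mlp_ratio=mlp_ratio)

# After this change: the block receives head count and head dim and computes
# hidden_size internally, matching the arguments of the fused Attention layer.
block = HunyuanVideoTransformerBlock(num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio)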