 from ...utils import is_torch_version
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
+from ..embeddings import get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
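For reference, the newly imported `get_1d_rotary_pos_embed` helper (from `diffusers.models.embeddings`) returns a `(cos, sin)` pair when called with `use_real=True`, which is how the RoPE module added below consumes it. A minimal sketch with illustrative sizes (the `dim=16`, 8 positions, and `theta=256.0` values here are examples, not the model's configuration):

```python
import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

# 8 positions along one axis, 16 rotary channels for that axis (example values).
positions = torch.arange(8, dtype=torch.float32)
cos, sin = get_1d_rotary_pos_embed(16, positions, theta=256.0, use_real=True)

# With the default repeat-interleave behaviour each tensor is (num_positions, dim).
print(cos.shape, sin.shape)  # torch.Size([8, 16]) torch.Size([8, 16])
```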
@@ -138,26 +139,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class PatchEmbed(nn.Module):
     def __init__(
         self,
-        patch_size=16,
-        in_chans=3,
-        embed_dim=768,
-        norm_layer=None,
-        flatten=True,
-        bias=True,
-    ):
+        patch_size: Union[int, Tuple[int, int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
         super().__init__()
 
-        patch_size = tuple(patch_size)
-        self.flatten = flatten
-        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
-        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+        patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
 
-    def forward(self, x):
-        x = self.proj(x)
-        if self.flatten:
-            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
-        x = self.norm(x)
-        return x
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)  # BCFHW -> BNC
+        return hidden_states
 
 
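To make the rewritten 3D patching concrete, here is a hedged shape walk-through equivalent to the new `forward`; all sizes are made-up example values:

```python
import torch
import torch.nn as nn

# Example values only: 16 latent channels, 9 frames, 32x32 spatial latents,
# and patch_size=(1, 2, 2), matching the (patch_size_t, patch_size, patch_size)
# tuple the transformer passes to PatchEmbed below.
proj = nn.Conv3d(16, 3072, kernel_size=(1, 2, 2), stride=(1, 2, 2))
video = torch.randn(1, 16, 9, 32, 32)  # (B, C, F, H, W)

tokens = proj(video)                        # (B, D, F/pt, H/p, W/p) = (1, 3072, 9, 16, 16)
tokens = tokens.flatten(2).transpose(1, 2)  # (B, N, D) with N = 9 * 16 * 16 = 2304
print(tokens.shape)  # torch.Size([1, 2304, 3072])
```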
 class TextProjection(nn.Module):
@@ -384,6 +378,39 @@ def forward(
         return hidden_states
 
 
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
+        self.rope_dim = rope_dim
+        self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
+
+        axes_grids = []
+        for i in range(3):
+            # Note: this line diverges from the original behaviour. We create the grid directly on the
+            # device, whereas the original implementation creates it on the CPU and then moves it to the
+            # device. This causes small numerical differences in layerwise debugging outputs, but the
+            # results are visually identical.
+            grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+            axes_grids.append(grid)
+        grid = torch.meshgrid(*axes_grids, indexing="ij")  # [T, H, W]
+        grid = torch.stack(grid, dim=0)  # [3, T, H, W]
+
+        freqs = []
+        for i in range(3):
+            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+            freqs.append(freq)
+
+        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (T * H * W, D / 2)
+        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (T * H * W, D / 2)
+        return freqs_cos, freqs_sin
+
+
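A quick usage sketch of the new RoPE module, assuming a hypothetical split of a 128-dim attention head into (16, 56, 56) rotary channels across the T/H/W axes (plausible example values, not necessarily the checkpoint's configuration):

```python
import torch

# Hypothetical configuration: temporal patch 1, spatial patch 2, head split 16 + 56 + 56.
rope = HunyuanVideoRotaryPosEmbed(patch_size=2, patch_size_t=1, rope_dim=[16, 56, 56], theta=256.0)

latents = torch.randn(1, 16, 9, 32, 32)  # (B, C, F, H, W)
freqs_cos, freqs_sin = rope(latents)

# One row per post-patch token: (9 // 1) * (32 // 2) * (32 // 2) = 2304 positions,
# with the per-axis rotary embeddings concatenated along the channel dimension.
print(freqs_cos.shape, freqs_sin.shape)
```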
 class HunyuanVideoSingleTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -546,12 +573,12 @@ def __init__(
         guidance_embeds: bool = True,
         text_embed_dim: int = 4096,
         text_embed_dim_2: int = 768,
+        rope_theta: float = 256.0,
     ) -> None:
         super().__init__()
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
-        self.rope_dim_list = rope_dim_list
 
         # image projection
         self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
@@ -570,6 +597,9 @@ def __init__(
         # guidance modulation
         self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU)
 
+        # 3. RoPE
+        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta)
+
         self.transformer_blocks = nn.ModuleList(
             [
                 HunyuanVideoTransformerBlock(
@@ -664,8 +694,6 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         encoder_hidden_states_2: torch.Tensor,
-        freqs_cos: Optional[torch.Tensor] = None,
-        freqs_sin: Optional[torch.Tensor] = None,
         guidance: torch.Tensor = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
@@ -676,6 +704,8 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p
 
+        image_rotary_emb = self.rope(hidden_states)
+
         temb = self.time_in(timestep)
         temb = temb + self.vector_in(encoder_hidden_states_2)
         temb = temb + self.guidance_in(guidance)
@@ -691,15 +721,14 @@ def forward(
             else lambda x: x
         )
 
-        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
         for _, block in enumerate(self.transformer_blocks):
             hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, freqs_cis
+                hidden_states, encoder_hidden_states, temb, image_rotary_emb
             )
 
         for block in self.single_transformer_blocks:
             hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, freqs_cis
+                hidden_states, encoder_hidden_states, temb, image_rotary_emb
             )
 
         hidden_states = self.norm_out(hidden_states, temb)
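With this change, callers no longer thread `freqs_cos` / `freqs_sin` through `forward`; the model derives the rotary embeddings from the latent shape via `self.rope`. A hedged call-site sketch follows; every tensor shape below is a placeholder, and `transformer` stands for an instance of the model defined in this file:

```python
import torch

def run_transformer(transformer, device="cpu"):
    # Placeholder shapes; real values depend on the pipeline and checkpoint.
    latents = torch.randn(1, 16, 9, 32, 32, device=device)     # (B, C, F, H, W)
    timestep = torch.tensor([999.0], device=device)
    prompt_embeds = torch.randn(1, 256, 4096, device=device)   # text_embed_dim
    prompt_mask = torch.ones(1, 256, dtype=torch.bool, device=device)
    pooled_prompt_embeds = torch.randn(1, 768, device=device)  # text_embed_dim_2
    guidance = torch.tensor([6016.0], device=device)           # embedded guidance scale

    # Note: no freqs_cos / freqs_sin arguments anymore; RoPE is computed internally.
    return transformer(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_mask,
        encoder_hidden_states_2=pooled_prompt_embeds,
        guidance=guidance,
        return_dict=True,
    )
```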