@@ -156,7 +156,9 @@ def apply_rotary_emb(
 class HunyuanVideoAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
 
     def __call__(
         self,
@@ -186,18 +188,24 @@ def __call__(
             from ..embeddings import apply_rotary_emb
 
             if attn.add_q_proj is None and encoder_hidden_states is not None:
-                query = torch.cat([
-                    apply_rotary_emb(query[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb),
-                    query[:, :, -encoder_hidden_states.shape[1]:],
-                ], dim=2)
-                key = torch.cat([
-                    apply_rotary_emb(key[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb),
-                    key[:, :, -encoder_hidden_states.shape[1]:],
-                ], dim=2)
+                query = torch.cat(
+                    [
+                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        query[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+                key = torch.cat(
+                    [
+                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        key[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
             else:
                 query = apply_rotary_emb(query, image_rotary_emb)
                 key = apply_rotary_emb(key, image_rotary_emb)
-
+
             if attn.add_q_proj is not None and encoder_hidden_states is not None:
                 encoder_query = attn.add_q_proj(encoder_hidden_states)
                 encoder_key = attn.add_k_proj(encoder_hidden_states)
@@ -224,8 +232,8 @@ def __call__(
 
         if encoder_hidden_states is not None:
             hidden_states, encoder_hidden_states = (
-                hidden_states[:, :-encoder_hidden_states.shape[1]],
-                hidden_states[:, -encoder_hidden_states.shape[1]:],
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
             )
 
         if not attn.pre_only:
@@ -513,7 +521,7 @@ def __init__(
 
         hidden_size = num_attention_heads * attention_head_dim
         mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
-
+
         self.hidden_size = hidden_size
         self.heads_num = num_attention_heads
         self.mlp_hidden_dim = mlp_hidden_dim
@@ -546,14 +554,17 @@ def forward(
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
-
+
         residual = hidden_states
-
+
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
-        norm_hidden_states, norm_encoder_hidden_states = norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :]
-
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
         # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
         # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
 
@@ -584,7 +595,10 @@ def forward(
         hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
         hidden_states = hidden_states + residual
 
-        hidden_states, encoder_hidden_states = hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :]
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
         return hidden_states, encoder_hidden_states
 
 
@@ -631,7 +645,7 @@ def forward(
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+
         img_qkv = self.img_attn_qkv(norm_hidden_states)
         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
         # Apply QK-Norm if needed
@@ -645,7 +659,7 @@ def forward(
             img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
         ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
         img_q, img_k = img_qq, img_kk
-
+
         txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states)
         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
         txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
@@ -830,7 +844,6 @@ def forward(
         encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask)
 
         txt_seq_len = encoder_hidden_states.shape[1]
-        img_seq_len = hidden_states.shape[1]
 
         freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
         for _, block in enumerate(self.transformer_blocks):
@@ -842,7 +855,7 @@ def forward(
             ]
 
             hidden_states, encoder_hidden_states = block(*double_block_args)
-
+
         for block in self.single_transformer_blocks:
             single_block_args = [
                 hidden_states,