@@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.processor(self, hidden_states)


+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
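+        # to_out projects the visual (latent) stream back to out_dim; the text stream gets its own
+        # to_add_out projection below unless context_pre_only is set.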
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
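+        # Reshape projections to (batch, seq_len, heads, head_dim) so the RMSNorms below act per head.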
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
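+            # Rotary embeddings rotate interleaved (even, odd) channel pairs in float32, then cast back
+            # to the input dtype; they are applied to the visual query/key only.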
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
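+        # attention_mask flags the valid (non-padded) text tokens of each sample. Run attention per sample
+        # over its visual tokens plus only its valid text tokens, then zero-pad the result back to total_length.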
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
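+        # Split the joint sequence back into the visual and text streams; the per-sample zero padding
+        # added above ends up at the tail of the text part.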
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
@@ -3868,94 +4039,6 @@ def __call__(
         return hidden_states


-class MochiAttnProcessor2_0:
-    """Attention processor used in Mochi."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        encoder_query = attn.add_q_proj(encoder_hidden_states)
-        encoder_key = attn.add_k_proj(encoder_hidden_states)
-        encoder_value = attn.add_v_proj(encoder_hidden_states)
-
-        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
-        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
-        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_added_q is not None:
-            encoder_query = attn.norm_added_q(encoder_query)
-        if attn.norm_added_k is not None:
-            encoder_key = attn.norm_added_k(encoder_key)
-
-        if image_rotary_emb is not None:
-
-            def apply_rotary_emb(x, freqs_cos, freqs_sin):
-                x_even = x[..., 0::2].float()
-                x_odd = x[..., 1::2].float()
-
-                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
-                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
-
-                return torch.stack([cos, sin], dim=-1).flatten(-2)
-
-            query = apply_rotary_emb(query, *image_rotary_emb)
-            key = apply_rotary_emb(key, *image_rotary_emb)
-
-        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = (
-            encoder_query.transpose(1, 2),
-            encoder_key.transpose(1, 2),
-            encoder_value.transpose(1, 2),
-        )
-
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
-
-        query = torch.cat([query, encoder_query], dim=2)
-        key = torch.cat([key, encoder_key], dim=2)
-        value = torch.cat([value, encoder_value], dim=2)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
-            (sequence_length, encoder_sequence_length), dim=1
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if getattr(attn, "to_add_out", None) is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -5668,13 +5751,13 @@ def __call__(
     AttnProcessorNPU,
     AttnProcessor2_0,
     MochiVaeAttnProcessor2_0,
+    MochiAttnProcessor2_0,
     StableAudioAttnProcessor2_0,
     HunyuanAttnProcessor2_0,
     FusedHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     LuminaAttnProcessor2_0,
-    MochiAttnProcessor2_0,
     FusedAttnProcessor2_0,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,