@@ -73,8 +73,7 @@ class EasyAnimateAttnProcessor2_0:
7373 used in the EasyAnimateTransformer3DModel model.
7474 """
7575
76- def __init__ (self , attn2 = None ):
77- self .attn2 = attn2
76+ def __init__ (self ):
7877 if not hasattr (F , "scaled_dot_product_attention" ):
7978 raise ImportError (
8079 "EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
@@ -84,11 +83,11 @@ def __call__(
8483 self ,
8584 attn : Attention ,
8685 hidden_states : torch .Tensor ,
87- encoder_hidden_states : Optional [ torch .Tensor ] = None ,
86+ encoder_hidden_states : torch .Tensor ,
8887 attention_mask : Optional [torch .Tensor ] = None ,
8988 image_rotary_emb : Optional [torch .Tensor ] = None ,
9089 ) -> torch .Tensor :
91- if self . attn2 is None and encoder_hidden_states is not None :
90+ if attn . add_q_proj is None and encoder_hidden_states is not None :
9291 hidden_states = torch .cat ([encoder_hidden_states , hidden_states ], dim = 1 )
9392
9493 # 1. QKV projections
@@ -107,19 +106,19 @@ def __call__(
107106 key = attn .norm_k (key )
108107
109108 # 3. Encoder condition QKV projection and normalization
110- if self . attn2 . to_q is not None and encoder_hidden_states is not None :
111- encoder_query = self . attn2 . to_q (encoder_hidden_states )
112- encoder_key = self . attn2 . to_k (encoder_hidden_states )
113- encoder_value = self . attn2 . to_v (encoder_hidden_states )
109+ if attn . add_q_proj is not None and encoder_hidden_states is not None :
110+ encoder_query = attn . add_q_proj (encoder_hidden_states )
111+ encoder_key = attn . add_k_proj (encoder_hidden_states )
112+ encoder_value = attn . add_v_proj (encoder_hidden_states )
114113
115114 encoder_query = encoder_query .unflatten (2 , (attn .heads , - 1 )).transpose (1 , 2 )
116115 encoder_key = encoder_key .unflatten (2 , (attn .heads , - 1 )).transpose (1 , 2 )
117116 encoder_value = encoder_value .unflatten (2 , (attn .heads , - 1 )).transpose (1 , 2 )
118117
119- if self . attn2 . norm_q is not None :
120- encoder_query = self . attn2 . norm_q (encoder_query )
121- if self . attn2 . norm_k is not None :
122- encoder_key = self . attn2 . norm_k (encoder_key )
118+ if attn . norm_added_q is not None :
119+ encoder_query = attn . norm_added_q (encoder_query )
120+ if attn . norm_added_k is not None :
121+ encoder_key = attn . norm_added_k (encoder_key )
123122
124123 query = torch .cat ([encoder_query , query ], dim = 2 )
125124 key = torch .cat ([encoder_key , key ], dim = 2 )
@@ -154,9 +153,8 @@ def __call__(
154153 hidden_states = attn .to_out [0 ](hidden_states )
155154 hidden_states = attn .to_out [1 ](hidden_states )
156155
157- if self .attn2 is not None and getattr (self .attn2 , "to_out" , None ) is not None :
158- encoder_hidden_states = self .attn2 .to_out [0 ](encoder_hidden_states )
159- encoder_hidden_states = self .attn2 .to_out [1 ](encoder_hidden_states )
156+ if getattr (attn , "to_add_out" , None ) is not None :
157+ encoder_hidden_states = attn .to_add_out (encoder_hidden_states )
160158 else :
161159 if getattr (attn , "to_out" , None ) is not None :
162160 hidden_states = attn .to_out [0 ](hidden_states )
@@ -192,27 +190,17 @@ def __init__(
192190 time_embed_dim , dim , norm_elementwise_affine , norm_eps , norm_type = norm_type , bias = True
193191 )
194192
195- if is_mmdit_block :
196- self .attn2 = Attention (
197- query_dim = dim ,
198- dim_head = attention_head_dim ,
199- heads = num_attention_heads ,
200- qk_norm = "layer_norm" if qk_norm else None ,
201- eps = 1e-6 ,
202- bias = True ,
203- processor = EasyAnimateAttnProcessor2_0 (),
204- )
205- else :
206- self .attn2 = None
207-
208193 self .attn1 = Attention (
209194 query_dim = dim ,
210195 dim_head = attention_head_dim ,
211196 heads = num_attention_heads ,
212197 qk_norm = "layer_norm" if qk_norm else None ,
213198 eps = 1e-6 ,
214199 bias = True ,
215- processor = EasyAnimateAttnProcessor2_0 (self .attn2 ),
200+ added_proj_bias = True ,
201+ added_kv_proj_dim = dim if is_mmdit_block else None ,
202+ context_pre_only = False if is_mmdit_block else None ,
203+ processor = EasyAnimateAttnProcessor2_0 (),
216204 )
217205
218206 # FFN Part
0 commit comments