@@ -136,6 +136,103 @@ def __call__(
 
         return hidden_states, encoder_hidden_states
 
+class FusedHunyuanVideoAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedHunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if attn.add_q_proj is None and encoder_hidden_states is not None:
+            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        # 1. QKV projections
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotary positional embeddings applied to the latent stream only
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if attn.add_q_proj is None and encoder_hidden_states is not None:
+                query = torch.cat(
+                    [
+                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        query[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+                key = torch.cat(
+                    [
+                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        key[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+            else:
+                query = apply_rotary_emb(query, image_rotary_emb)
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # 4. Encoder condition QKV projection and normalization
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([query, encoder_query], dim=2)
+            key = torch.cat([key, encoder_key], dim=2)
+            value = torch.cat([value, encoder_value], dim=2)
+
+        # 5. Attention
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # 6. Output projection
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
+            )
+
+        if getattr(attn, "to_out", None) is not None:
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+
+        if getattr(attn, "to_add_out", None) is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
 
 class HunyuanVideoPatchEmbed(nn.Module):
     def __init__(
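
The heart of the new processor is the single fused `to_qkv` matmul that replaces the separate `to_q`/`to_k`/`to_v` projections; its output is split back into per-head query/key/value tensors before attention. A minimal, self-contained sketch of that shape bookkeeping (all sizes here are illustrative stand-ins, not HunyuanVideo's real dimensions):

```python
import torch

# Illustrative sizes only; not HunyuanVideo's actual config.
batch, seq_len, heads, head_dim = 2, 16, 24, 128
qkv = torch.randn(batch, seq_len, 3 * heads * head_dim)  # stands in for attn.to_qkv(hidden_states)

split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)  # three (batch, seq, heads * head_dim) chunks

# (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
query = query.unflatten(2, (heads, -1)).transpose(1, 2)
key = key.unflatten(2, (heads, -1)).transpose(1, 2)
value = value.unflatten(2, (heads, -1)).transpose(1, 2)
assert query.shape == (batch, heads, seq_len, head_dim)
```

Note that in the no-`add_q_proj` path the text tokens are concatenated after the latent tokens, which is why step 3 applies the rotary embedding only to the leading positions and passes the trailing `encoder_hidden_states.shape[1]` tokens through unchanged.
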
@@ -214,6 +311,10 @@ def forward(
         )
 
         gate_msa, gate_mlp = self.norm_out(temb)
+        # The fused attention processor returns a tuple; keep only hidden_states
+        if isinstance(attn_output, tuple):
+            attn_output = attn_output[0]
+
         hidden_states = hidden_states + attn_output * gate_msa
 
         ff_output = self.ff(self.norm2(hidden_states))
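
The guard added here compensates for the fused processor's return convention: `FusedHunyuanVideoAttnProcessor2_0.__call__` always returns a `(hidden_states, encoder_hidden_states)` pair, while the single-stream block previously received a bare tensor. A self-contained sketch of the unpacking under that assumption (the stand-in function and shapes are hypothetical):

```python
from typing import Optional, Tuple
import torch

# Hypothetical stand-in for the fused processor: it returns a pair even
# when there is no encoder stream and the second element is None.
def processor_output(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return x, None

attn_output = processor_output(torch.randn(2, 16, 3072))
if isinstance(attn_output, tuple):
    attn_output = attn_output[0]  # keep only the image-stream activations
```
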
@@ -604,6 +705,46 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanVideoAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedHunyuanVideoAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
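
With these two methods the transformer can toggle fusion at runtime, mirroring the other diffusers models that expose the same helpers. A hedged usage sketch; the checkpoint id, subfolder, and dtype are assumptions for illustration, not taken from this diff:

```python
import torch
from diffusers import HunyuanVideoTransformer3DModel

# Assumed checkpoint id and dtype, for illustration only.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)

transformer.fuse_qkv_projections()    # one fused QKV matmul per self-attention layer
# ... run inference ...
transformer.unfuse_qkv_projections()  # restore the original, unfused processors
```
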