@@ -3507,6 +3507,109 @@ def __call__(
         return hidden_states
 
 
+class EasyAnimateAttnProcessor2_0:
+    r"""
+    Attention processor used in EasyAnimate. It computes scaled dot-product attention over the concatenated text and
+    image token sequences. When a second `Attention` module is passed as `attn2`, the text tokens use its separate
+    QKV and output projections; otherwise both modalities share the projections of `attn`.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "EasyAnimateAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        attn2: Optional[Attention] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:  # returns (hidden_states, encoder_hidden_states); assumes `Tuple` is imported
+        text_seq_length = encoder_hidden_states.size(1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
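+        # `F.scaled_dot_product_attention` expects a mask broadcastable to
+        # (batch, heads, query_len, key_len), so the prepared mask is reshaped below.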
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
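+        # Shared-projection path: without a separate text attention module, prepend the text
+        # tokens and run a single set of QKV projections over the joint sequence.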
+        if attn2 is None:
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
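+        # Separate-projection path: the text tokens get their own QKV projections from `attn2`,
+        # then text and image are concatenated along the sequence dimension (dim=2 after the
+        # transpose above) so attention is still computed jointly over both streams.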
+        if attn2 is not None:
+            query_txt = attn2.to_q(encoder_hidden_states)
+            key_txt = attn2.to_k(encoder_hidden_states)
+            value_txt = attn2.to_v(encoder_hidden_states)
+
+            inner_dim = key_txt.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            if attn2.norm_q is not None:
+                query_txt = attn2.norm_q(query_txt)
+            if attn2.norm_k is not None:
+                key_txt = attn2.norm_k(key_txt)
+
+            query = torch.cat([query_txt, query], dim=2)
+            key = torch.cat([key_txt, key], dim=2)
+            value = torch.cat([value_txt, value], dim=2)
+
+        # Apply RoPE to the image tokens only; the text tokens carry no spatial position.
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
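+            # For reference, `apply_rotary_emb` rotates each channel pair by a position-dependent
+            # angle, roughly x_out = x * cos + rotate_half(x) * sin, where `image_rotary_emb`
+            # supplies the cos/sin tables (a sketch of the helper's behavior, not its exact code).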
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        if attn2 is None:
+            # Shared projections: project the joint sequence first, then split it back into
+            # text and image streams.
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            # Separate projections: split first so each stream goes through its own output layer.
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
+        return hidden_states, encoder_hidden_states
+
+
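+# A minimal usage sketch (module names and sizes below are illustrative, not taken from the
+# EasyAnimate model code): the processor is attached to the image-stream `Attention` module,
+# and the text-stream module is passed through as `attn2` so both streams attend jointly
+# while keeping separate projections.
+#
+#     attn = Attention(query_dim=1536, heads=12, dim_head=128, bias=True,
+#                      qk_norm="layer_norm", processor=EasyAnimateAttnProcessor2_0())
+#     attn2 = Attention(query_dim=1536, heads=12, dim_head=128, bias=True, qk_norm="layer_norm")
+#     hidden_states, encoder_hidden_states = attn(
+#         hidden_states,                      # image tokens, (batch, image_seq_len, dim)
+#         encoder_hidden_states,              # text tokens, (batch, text_seq_len, dim)
+#         image_rotary_emb=image_rotary_emb,  # cos/sin tables for the image positions
+#         attn2=attn2,
+#     )
+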
 class StableAudioAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is