@@ -3613,109 +3613,6 @@ def __call__(
36133613 return hidden_states
36143614
36153615
class EasyAnimateAttnProcessor2_0:
    r"""
    Attention processor used in EasyAnimate.

    Performs joint attention over the concatenation of the text (encoder) and
    image (hidden) token streams with ``F.scaled_dot_product_attention``.
    Two wiring modes are supported:

    - ``attn2 is None``: the two streams are concatenated *before* projection,
      so a single shared set of QKV and output layers handles both.
    - ``attn2`` provided: the text stream is projected by ``attn2``'s own QKV
      layers, concatenated with the image queries/keys/values after head
      splitting, and each stream is sent through its own output projection.

    Returns a ``(hidden_states, encoder_hidden_states)`` pair in both modes.
    """

    def __init__(self):
        # scaled_dot_product_attention only exists from PyTorch 2.0 onwards.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        attn2: Attention = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # NOTE(review): sequence_length is taken from the *encoder* states when
        # they are present (matching the original), even though the attention
        # itself runs over text+image combined — confirm against callers that
        # pass an attention_mask.
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Broadcastable shape for SDPA: (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn2 is None:
            # Shared projections: fold the text tokens in front of the image
            # tokens so one QKV pass covers the whole joint sequence.
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            per_head = projected.shape[-1] // attn.heads
            return projected.view(batch_size, -1, attn.heads, per_head).transpose(1, 2)

        query = split_heads(attn.to_q(hidden_states))
        key = split_heads(attn.to_k(hidden_states))
        value = split_heads(attn.to_v(hidden_states))
        head_dim = key.shape[-1]

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if attn2 is not None:
            # Separate text-stream projections, then prepend text heads to the
            # image heads along the sequence axis (dim=2 after head split).
            query_txt = split_heads(attn2.to_q(encoder_hidden_states))
            key_txt = split_heads(attn2.to_k(encoder_hidden_states))
            value_txt = split_heads(attn2.to_v(encoder_hidden_states))
            head_dim = key_txt.shape[-1]

            if attn2.norm_q is not None:
                query_txt = attn2.norm_q(query_txt)
            if attn2.norm_k is not None:
                key_txt = attn2.norm_k(key_txt)

            query = torch.cat([query_txt, query], dim=2)
            key = torch.cat([key_txt, key], dim=2)
            value = torch.cat([value_txt, value], dim=2)

        if image_rotary_emb is not None:
            # RoPE applies only to the image tokens; text tokens are untouched.
            from .embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        if attn2 is None:
            # Project the joint sequence through the shared output layers,
            # then split it back into its text and image parts.
            hidden_states = attn.to_out[0](hidden_states)  # linear proj
            hidden_states = attn.to_out[1](hidden_states)  # dropout
            encoder_hidden_states, hidden_states = hidden_states.split(
                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
            )
        else:
            # Split first, then push each stream through its own output layers.
            encoder_hidden_states, hidden_states = hidden_states.split(
                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)  # linear proj
            encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
            hidden_states = attn.to_out[1](hidden_states)  # dropout
            encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
        return hidden_states, encoder_hidden_states
3717-
3718-
37193616class StableAudioAttnProcessor2_0 :
37203617 r"""
37213618 Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
0 commit comments