
Commit 19fcc7d

delete process
1 parent 914f460 commit 19fcc7d

File tree

1 file changed (+0, -103 lines)


src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 103 deletions
```diff
@@ -3613,109 +3613,6 @@ def __call__(
         return hidden_states
 
 
-class EasyAnimateAttnProcessor2_0:
-    r"""
-    Attention processor used in EasyAnimate.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        attn2: Attention = None,
-    ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn2 is None:
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        if attn2 is not None:
-            query_txt = attn2.to_q(encoder_hidden_states)
-            key_txt = attn2.to_k(encoder_hidden_states)
-            value_txt = attn2.to_v(encoder_hidden_states)
-
-            inner_dim = key_txt.shape[-1]
-            head_dim = inner_dim // attn.heads
-
-            query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            if attn2.norm_q is not None:
-                query_txt = attn2.norm_q(query_txt)
-            if attn2.norm_k is not None:
-                key_txt = attn2.norm_k(key_txt)
-
-            query = torch.cat([query_txt, query], dim=2)
-            key = torch.cat([key_txt, key], dim=2)
-            value = torch.cat([value_txt, value], dim=2)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-            if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
-        if attn2 is None:
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states, hidden_states = hidden_states.split(
-                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-            )
-        else:
-            encoder_hidden_states, hidden_states = hidden_states.split(
-                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-            )
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
-        return hidden_states, encoder_hidden_states
-
-
 class StableAudioAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
```
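For reference, the removed `EasyAnimateAttnProcessor2_0` concatenated the text tokens (`encoder_hidden_states`) and image tokens (`hidden_states`) into one sequence, ran joint scaled dot-product attention over it, and split the result back into the two streams. A minimal sketch of that core pattern follows; the shapes and module names are illustrative stand-ins, not the diffusers `Attention` API:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes, not taken from EasyAnimate.
batch, text_len, image_len, heads, head_dim = 2, 7, 16, 4, 8
dim = heads * head_dim

text_tokens = torch.randn(batch, text_len, dim)    # stands in for encoder_hidden_states
image_tokens = torch.randn(batch, image_len, dim)  # stands in for hidden_states

# Single-stream variant (attn2 is None): one shared QKV projection
# over the concatenated sequence.
to_q, to_k, to_v = (torch.nn.Linear(dim, dim) for _ in range(3))
tokens = torch.cat([text_tokens, image_tokens], dim=1)  # (batch, text+image, dim)

def split_heads(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq, dim) -> (batch, heads, seq, head_dim)
    return x.view(batch, -1, heads, head_dim).transpose(1, 2)

q, k, v = split_heads(to_q(tokens)), split_heads(to_k(tokens)), split_heads(to_v(tokens))

# Joint attention over text and image positions at once.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, dim)

# Split back into per-stream outputs, as the deleted __call__ did before
# returning (hidden_states, encoder_hidden_states).
text_out, image_out = out.split([text_len, image_len], dim=1)
print(text_out.shape, image_out.shape)  # torch.Size([2, 7, 32]) torch.Size([2, 16, 32])
```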

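When a second `Attention` module (`attn2`) was supplied, the processor instead projected the text stream through `attn2`'s own `to_q`/`to_k`/`to_v`, concatenated the per-head tensors along the sequence axis (`dim=2`), and later applied separate `to_out` projections per stream. A sketch of that asymmetric-projection idea, again with hypothetical modules rather than the real classes:

```python
import torch

batch, text_len, image_len, heads, head_dim = 2, 7, 16, 4, 8
dim = heads * head_dim

# Separate projection weights per stream, mirroring the attn / attn2 split.
image_qkv = torch.nn.Linear(dim, 3 * dim)  # hypothetical fused stand-in for attn.to_{q,k,v}
text_qkv = torch.nn.Linear(dim, 3 * dim)   # hypothetical stand-in for attn2.to_{q,k,v}

def project(x: torch.Tensor, qkv: torch.nn.Linear):
    # (batch, seq, dim) -> three (batch, heads, seq, head_dim) tensors
    q, k, v = qkv(x).chunk(3, dim=-1)
    return [t.view(x.size(0), -1, heads, head_dim).transpose(1, 2) for t in (q, k, v)]

q_img, k_img, v_img = project(torch.randn(batch, image_len, dim), image_qkv)
q_txt, k_txt, v_txt = project(torch.randn(batch, text_len, dim), text_qkv)

# Text first, matching the deleted processor's ordering; dim=2 is the
# sequence axis after the head transpose.
q = torch.cat([q_txt, q_img], dim=2)
k = torch.cat([k_txt, k_img], dim=2)
v = torch.cat([v_txt, v_img], dim=2)
print(q.shape)  # torch.Size([2, 4, 23, 8])
```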
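Finally, the rotary-embedding step applied RoPE only to the image positions, leaving the first `text_seq_length` tokens untouched. diffusers ships `apply_rotary_emb` in `.embeddings`; the version below is a simplified rotate-half stand-in with a hypothetical frequency table, just to show the slicing:

```python
import torch

def apply_rotary_emb_sketch(x, cos, sin):
    # Rotate-half RoPE; simplified stand-in for diffusers' apply_rotary_emb.
    # x: (batch, heads, seq, head_dim); cos/sin: (seq, head_dim).
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat([-x2, x1], dim=-1) * sin

batch, heads, text_len, image_len, head_dim = 2, 4, 7, 16, 8
q = torch.randn(batch, heads, text_len + image_len, head_dim)

# Hypothetical cos/sin tables covering only the image positions.
half = head_dim // 2
freqs = 1.0 / 10000 ** (torch.arange(half).float() / half)
angles = torch.outer(torch.arange(image_len).float(), freqs)  # (image_len, half)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)

# As in the deleted code: rotate only the tokens after the text prefix.
q[:, :, text_len:] = apply_rotary_emb_sketch(q[:, :, text_len:], cos, sin)
```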