Commit 5609fc2
Update EasyAnimate V5.1
1 parent: 95c5ce4

17 files changed: +6,149 / -0 lines

.gitignore (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 # Initially taken from GitHub's Python gitignore file
+__*/

 # Byte-compiled / optimized / DLL files
 __pycache__/

src/diffusers/__init__.py (10 additions, 0 deletions)

@@ -86,6 +86,7 @@
             "AutoencoderKLCogVideoX",
             "AutoencoderKLHunyuanVideo",
             "AutoencoderKLLTXVideo",
+            "AutoencoderKLMagvit",
             "AutoencoderKLMochi",
             "AutoencoderKLTemporalDecoder",
             "AutoencoderOobleck",
@@ -97,6 +98,7 @@
             "ControlNetUnionModel",
             "ControlNetXSAdapter",
             "DiTTransformer2DModel",
+            "EasyAnimateTransformer3DModel",
             "FluxControlNetModel",
             "FluxMultiControlNetModel",
             "FluxTransformer2DModel",
@@ -276,6 +278,9 @@
             "CogVideoXVideoToVideoPipeline",
             "CogView3PlusPipeline",
             "CycleDiffusionPipeline",
+            "EasyAnimatePipeline",
+            "EasyAnimateInpaintPipeline",
+            "EasyAnimateControlPipeline",
             "FluxControlImg2ImgPipeline",
             "FluxControlInpaintPipeline",
             "FluxControlNetImg2ImgPipeline",
@@ -596,6 +601,7 @@
         AutoencoderKLCogVideoX,
         AutoencoderKLHunyuanVideo,
         AutoencoderKLLTXVideo,
+        AutoencoderKLMagvit,
         AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
@@ -607,6 +613,7 @@
         ControlNetUnionModel,
         ControlNetXSAdapter,
         DiTTransformer2DModel,
+        EasyAnimateTransformer3DModel,
         FluxControlNetModel,
         FluxMultiControlNetModel,
         FluxTransformer2DModel,
@@ -765,6 +772,9 @@
         CogVideoXVideoToVideoPipeline,
         CogView3PlusPipeline,
         CycleDiffusionPipeline,
+        EasyAnimatePipeline,
+        EasyAnimateInpaintPipeline,
+        EasyAnimateControlPipeline,
         FluxControlImg2ImgPipeline,
         FluxControlInpaintPipeline,
         FluxControlNetImg2ImgPipeline,
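
With these exports, the new EasyAnimate classes become importable from the top-level diffusers namespace. A minimal text-to-video sketch of the newly exported EasyAnimatePipeline, assuming standard diffusers pipeline conventions; the checkpoint id and call arguments below are assumptions for illustration, not part of this commit:

# Sketch: text-to-video with the newly exported EasyAnimatePipeline.
# The checkpoint id is assumed for illustration.
import torch

from diffusers import EasyAnimatePipeline
from diffusers.utils import export_to_video

pipe = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

frames = pipe(prompt="A cat walks through tall grass", num_frames=49).frames[0]
export_to_video(frames, "easyanimate.mp4", fps=8)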

src/diffusers/models/__init__.py (4 additions, 0 deletions)

@@ -31,6 +31,7 @@
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
     _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+    _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
     _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
     _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -54,6 +55,7 @@
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
     _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
+    _import_structure["transformers.easyanimate_transformer_3d"] = ["EasyAnimateTransformer3DModel"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
@@ -101,6 +103,7 @@
         AutoencoderKLCogVideoX,
         AutoencoderKLHunyuanVideo,
         AutoencoderKLLTXVideo,
+        AutoencoderKLMagvit,
         AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
@@ -131,6 +134,7 @@
         CogView3PlusTransformer2DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
+        EasyAnimateTransformer3DModel,
         FluxTransformer2DModel,
         HunyuanDiT2DModel,
         HunyuanVideoTransformer3DModel,
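
These _import_structure entries feed diffusers' lazy-import machinery: each key names a submodule and each value lists the symbols it exports, and nothing is actually imported until a name is first accessed. For context (pre-existing code, unchanged by this commit), the registration is consumed at the bottom of the file roughly like this:

# Pre-existing lazy-import plumbing (shown for context, not part of this diff):
# the module object is replaced by a _LazyModule that resolves names on demand.
import sys

from ..utils import _LazyModule

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)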

src/diffusers/models/attention_processor.py (103 additions, 0 deletions)

@@ -3507,6 +3507,109 @@ def __call__(
         return hidden_states


+class EasyAnimateAttnProcessor2_0:
+    r"""
+    Attention processor used in EasyAnimate.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        attn2: Attention = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn2 is None:
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        if attn2 is not None:
+            query_txt = attn2.to_q(encoder_hidden_states)
+            key_txt = attn2.to_k(encoder_hidden_states)
+            value_txt = attn2.to_v(encoder_hidden_states)
+
+            inner_dim = key_txt.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            if attn2.norm_q is not None:
+                query_txt = attn2.norm_q(query_txt)
+            if attn2.norm_k is not None:
+                key_txt = attn2.norm_k(key_txt)
+
+            query = torch.cat([query_txt, query], dim=2)
+            key = torch.cat([key_txt, key], dim=2)
+            value = torch.cat([value_txt, value], dim=2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        if attn2 is None:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
+        return hidden_states, encoder_hidden_states
+
+
 class StableAudioAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
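
The new processor supports two layouts. With attn2=None, text and video tokens are concatenated up front and share one set of QKV and output projections. When a second Attention module is passed as attn2, the text tokens get their own projections and output layer, the two streams are concatenated only for the joint scaled-dot-product attention, and the result is split back apart afterwards. A minimal wiring sketch follows; the dimensions and Attention constructor arguments are illustrative, not taken from this commit:

# Sketch: joint text-video attention with separate text-stream projections.
# Hyperparameters below are made up for illustration.
import torch

from diffusers.models.attention_processor import Attention, EasyAnimateAttnProcessor2_0

attn = Attention(
    query_dim=1536, heads=12, dim_head=128, bias=True, qk_norm="layer_norm",
    processor=EasyAnimateAttnProcessor2_0(),
)
attn2 = Attention(query_dim=1536, heads=12, dim_head=128, bias=True, qk_norm="layer_norm")

video_tokens = torch.randn(1, 1024, 1536)  # flattened spatio-temporal tokens
text_tokens = torch.randn(1, 77, 1536)     # encoded prompt tokens

# Attention.forward passes attn2 through to the processor; the joint
# attention result comes back split into the video and text streams.
hidden_states, encoder_hidden_states = attn(
    hidden_states=video_tokens, encoder_hidden_states=text_tokens, attn2=attn2
)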

src/diffusers/models/autoencoders/__init__.py (1 addition, 0 deletions)

@@ -5,6 +5,7 @@
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_oobleck import AutoencoderOobleck
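
AutoencoderKLMagvit is the MagViT-style 3D video VAE used by EasyAnimate; like the other autoencoders here, it can also be loaded standalone through the usual ModelMixin interface. A minimal sketch; the checkpoint id, subfolder, and tensor shape are assumptions for illustration:

# Sketch: round-trip a clip through the EasyAnimate video VAE.
# Checkpoint id and subfolder are assumed, not specified by this commit.
import torch

from diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Dummy clip: (batch, channels, num_frames, height, width).
video = torch.randn(1, 3, 17, 256, 256, dtype=torch.float16, device="cuda")
latents = vae.encode(video).latent_dist.sample()
reconstruction = vae.decode(latents).sample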
