
Commit 96ccfb5

Merge branch 'main' of https://github.com/bubbliiiing/diffusers into easyanimate
2 parents: 9e8a249 + 2e1b4f5

4 files changed (+188, -244 lines)


src/diffusers/models/transformers/transformer_easyanimate.py

Lines changed: 48 additions & 2 deletions
@@ -18,12 +18,13 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
-from ..embeddings import TimestepEmbedding, Timesteps
+from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
@@ -67,6 +68,50 @@ def forward(
         return hidden_states, encoder_hidden_states, gate, enc_gate
 
 
+class EasyAnimateRotaryPosEmbed(nn.Module):
+    def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.rope_dim = rope_dim
+
+    def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
+        tw = tgt_width
+        th = tgt_height
+        h, w = src
+        r = h / w
+        if r > (th / tw):
+            resize_height = th
+            resize_width = int(round(th / h * w))
+        else:
+            resize_width = tw
+            resize_height = int(round(tw / w * h))
+
+        crop_top = int(round((th - resize_height) / 2.0))
+        crop_left = int(round((tw - resize_width) / 2.0))
+
+        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        bs, c, num_frames, grid_height, grid_width = hidden_states.size()
+        grid_height = grid_height // self.patch_size
+        grid_width = grid_width // self.patch_size
+        base_size_width = 90 // self.patch_size
+        base_size_height = 60 // self.patch_size
+
+        grid_crops_coords = self.get_resize_crop_region_for_grid(
+            (grid_height, grid_width), base_size_width, base_size_height
+        )
+        image_rotary_emb = get_3d_rotary_pos_embed(
+            self.rope_dim,
+            grid_crops_coords,
+            grid_size=(grid_height, grid_width),
+            temporal_size=hidden_states.size(2),
+            use_real=True,
+        )
+        return image_rotary_emb
+
+
 class EasyAnimateAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -361,6 +406,7 @@ def __init__(
         # 1. Timestep embedding
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+        self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
 
         # 2. Patch embedding
         self.proj = nn.Conv2d(
@@ -422,7 +468,6 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_hidden_states_t5: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
         inpaint_latents: Optional[torch.Tensor] = None,
         control_latents: Optional[torch.Tensor] = None,
         return_dict: bool = True,
@@ -435,6 +480,7 @@ def forward(
         # 1. Time embedding
         temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
         temb = self.time_embedding(temb, timestep_cond)
+        image_rotary_emb = self.rope_embedding(hidden_states)
 
         # 2. Patch embedding
         if inpaint_latents is not None:
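
Note on the new module above: a minimal smoke-test sketch with illustrative values — the patch size, head dimension, and latent shape are assumptions, not taken from this commit, and the import assumes a diffusers build that contains this branch.

import torch
from diffusers.models.transformers.transformer_easyanimate import EasyAnimateRotaryPosEmbed

# The transformer constructs this with (patch_size, attention_head_dim); 2 and 72 are illustrative.
rope = EasyAnimateRotaryPosEmbed(patch_size=2, rope_dim=72)

# Video latents in (B, C, F, H, W) layout, the same tensor forward() receives.
latents = torch.randn(1, 16, 9, 60, 90)

# With use_real=True, get_3d_rotary_pos_embed yields a (cos, sin) pair of rotary tables.
cos, sin = rope(latents)
print(cos.shape, sin.shape)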

src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py

(file mode changed: 100644 → 100755)
Lines changed: 2 additions & 27 deletions
@@ -56,7 +56,7 @@
 
         >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh"
         >>> pipe = EasyAnimatePipeline.from_pretrained(
-        ...     "alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.float16
+        ...     "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
         ... ).to("cuda")
         >>> prompt = (
         ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
@@ -877,30 +877,6 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7 create image_rotary_emb, style embedding & time ids
-        grid_height = height // 8 // self.transformer.config.patch_size
-        grid_width = width // 8 // self.transformer.config.patch_size
-        if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
-            base_size_width = 720 // 8 // self.transformer.config.patch_size
-            base_size_height = 480 // 8 // self.transformer.config.patch_size
-
-            grid_crops_coords = get_resize_crop_region_for_grid(
-                (grid_height, grid_width), base_size_width, base_size_height
-            )
-            image_rotary_emb = get_3d_rotary_pos_embed(
-                self.transformer.config.attention_head_dim,
-                grid_crops_coords,
-                grid_size=(grid_height, grid_width),
-                temporal_size=latents.size(2),
-                use_real=True,
-            )
-        else:
-            base_size = 512 // 8 // self.transformer.config.patch_size
-            grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size, base_size)
-            image_rotary_emb = get_2d_rotary_pos_embed(
-                self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
-            )
-
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
@@ -915,7 +891,7 @@ def __call__(
         prompt_embeds_2 = prompt_embeds_2.to(device=device)
         prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
 
-        # 8. Denoising loop
+        # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -939,7 +915,6 @@ def __call__(
                     t_expand,
                     encoder_hidden_states=prompt_embeds,
                     encoder_hidden_states_t5=prompt_embeds_2,
-                    image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )[0]
 
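
For context, a minimal usage sketch consistent with the updated example docstring: the caller no longer builds or passes image_rotary_emb, since the transformer now derives it from the latent shape. The checkpoint name comes from the diff; the frame count, the .frames output attribute, and a CUDA device are assumptions.

import torch
from diffusers import EasyAnimatePipeline
from diffusers.utils import export_to_video

pipe = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
).to("cuda")

prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest."
# No rotary-embedding tensor is constructed here anymore.
video = pipe(prompt=prompt, num_frames=49).frames[0]
export_to_video(video, "panda.mp4", fps=8)
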
src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py

(file mode changed: 100644 → 100755)
Lines changed: 54 additions & 82 deletions
@@ -19,7 +19,6 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from einops import rearrange
 from PIL import Image
 from transformers import (
     BertModel,
@@ -61,7 +60,7 @@
         >>> from diffusers.utils import export_to_video, load_video
 
         >>> pipe = EasyAnimateControlPipeline.from_pretrained(
-        ...     "alibaba-pai/EasyAnimateV5.1-12b-zh-Control", torch_dtype=torch.bfloat16
+        ...     "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
         ... )
         >>> pipe.to("cuda")
 
@@ -84,7 +83,7 @@
         >>> video = pipe(
         ...     prompt,
         ...     num_frames=num_frames,
-        ...     negative_prompt="bad detailed",
+        ...     negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
         ...     height=sample_size[0],
         ...     width=sample_size[1],
         ...     control_video=input_video,
@@ -93,53 +92,53 @@
 ```
 """
 
+def preprocess_image(image, sample_size):
+    """
+    Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
+    """
+    if isinstance(image, torch.Tensor):
+        # If input is a tensor, assume it's in CHW format and resize using interpolation
+        image = torch.nn.functional.interpolate(
+            image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
+        ).squeeze(0)
+    elif isinstance(image, Image.Image):
+        # If input is a PIL image, resize and convert to numpy array
+        image = image.resize((sample_size[1], sample_size[0]))
+        image = np.array(image)
+    elif isinstance(image, np.ndarray):
+        # If input is a numpy array, resize using PIL
+        image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
+        image = np.array(image)
+    else:
+        raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
 
-def get_video_to_video_latent(
-    input_video_path, num_frames, sample_size, fps=None, validation_video_mask=None, ref_image=None
-):
-    if input_video_path is not None:
-        if isinstance(input_video_path, str):
-            import cv2
-
-            cap = cv2.VideoCapture(input_video_path)
-            input_video = []
-
-            original_fps = cap.get(cv2.CAP_PROP_FPS)
-            frame_skip = 1 if fps is None else int(original_fps // fps)
-
-            frame_count = 0
+    # Convert to tensor if not already
+    if not isinstance(image, torch.Tensor):
+        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # HWC -> CHW, normalize to [0, 1]
 
-            while True:
-                ret, frame = cap.read()
-                if not ret:
-                    break
+    return image
 
-                if frame_count % frame_skip == 0:
-                    frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
-                    input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
-                frame_count += 1
+def get_video_to_video_latent(
+    input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None
+):
+    if input_video is not None:
+        # Convert each frame in the list to tensor
+        input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
 
-            cap.release()
-        else:
-            input_video = input_video_path
+        # Stack all frames into a single tensor (F, C, H, W)
+        input_video = torch.stack(input_video)[:num_frames]
 
-        input_video = torch.from_numpy(np.array(input_video))[:num_frames]
-        input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
+        # Add batch dimension (B, F, C, H, W)
+        input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
 
         if validation_video_mask is not None:
-            validation_video_mask = (
-                Image.open(validation_video_mask).convert("L").resize((sample_size[1], sample_size[0]))
-            )
-            input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
-
-            input_video_mask = (
-                torch.from_numpy(np.array(input_video_mask))
-                .unsqueeze(0)
-                .unsqueeze(-1)
-                .permute([3, 0, 1, 2])
-                .unsqueeze(0)
-            )
+            # Handle mask input
+            validation_video_mask = preprocess_image(validation_video_mask, size=sample_size)
+            input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
+
+            # Adjust mask dimensions to match video
+            input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
             input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
             input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
         else:
@@ -149,14 +148,12 @@ def get_video_to_video_latent(
         input_video, input_video_mask = None, None
 
     if ref_image is not None:
-        if isinstance(ref_image, str):
-            ref_image = Image.open(ref_image).convert("RGB")
-            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
-            ref_image = torch.from_numpy(np.array(ref_image))
-            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
-        else:
-            ref_image = torch.from_numpy(np.array(ref_image))
-            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+        # Convert reference image to tensor
+        ref_image = preprocess_image(ref_image, size=sample_size)
+        ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0)  # Add batch dimension (B, C, H, W)
+    else:
+        ref_image = None
+
     return input_video, input_video_mask, ref_image
 
 
@@ -1025,12 +1022,12 @@ def __call__(
                 torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
             ).to(device, dtype)
         elif control_video is not None:
-            num_frames = control_video.shape[2]
+            batch_size, channels, num_frames, height_video, width_video = control_video.shape
             control_video = self.image_processor.preprocess(
-                rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width
+                control_video.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), height=height, width=width
             )
             control_video = control_video.to(dtype=torch.float32)
-            control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=num_frames)
+            control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
             control_video_latents = self.prepare_control_latents(
                 None,
                 control_video,
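
With the einops import dropped, the rearrange calls are replaced by the permute/reshape pairs above. A small self-contained check (illustrative shapes) that the round trip matches the removed "b c f h w -> (b f) c h w" and "(b f) c h w -> b c f h w" patterns:

import torch

b, c, f, h, w = 2, 3, 4, 8, 8
x = torch.arange(b * c * f * h * w, dtype=torch.float32).reshape(b, c, f, h, w)

# "b c f h w -> (b f) c h w"
flat = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
# "(b f) c h w -> b c f h w"
back = flat.reshape(b, f, c, h, w).permute(0, 2, 1, 3, 4)

assert torch.equal(back, x)
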
@@ -1052,12 +1049,12 @@ def __call__(
             ).to(device, dtype)
 
         if ref_image is not None:
-            num_frames = ref_image.shape[2]
+            batch_size, channels, num_frames, height_video, width_video = ref_image.shape
             ref_image = self.image_processor.preprocess(
-                rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width
+                ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), height=height, width=width
             )
             ref_image = ref_image.to(dtype=torch.float32)
-            ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=num_frames)
+            ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
 
             ref_image_latentes = self.prepare_control_latents(
                 None,
@@ -1092,30 +1089,6 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7 create image_rotary_emb, style embedding & time ids
-        grid_height = height // 8 // self.transformer.config.patch_size
-        grid_width = width // 8 // self.transformer.config.patch_size
-        if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
-            base_size_width = 720 // 8 // self.transformer.config.patch_size
-            base_size_height = 480 // 8 // self.transformer.config.patch_size
-
-            grid_crops_coords = get_resize_crop_region_for_grid(
-                (grid_height, grid_width), base_size_width, base_size_height
-            )
-            image_rotary_emb = get_3d_rotary_pos_embed(
-                self.transformer.config.attention_head_dim,
-                grid_crops_coords,
-                grid_size=(grid_height, grid_width),
-                temporal_size=latents.size(2),
-                use_real=True,
-            )
-        else:
-            base_size = 512 // 8 // self.transformer.config.patch_size
-            grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size, base_size)
-            image_rotary_emb = get_2d_rotary_pos_embed(
-                self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
-            )
-
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
@@ -1130,7 +1103,7 @@ def __call__(
         prompt_embeds_2 = prompt_embeds_2.to(device=device)
         prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
 
-        # 8. Denoising loop
+        # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1153,7 +1126,6 @@ def __call__(
                     t_expand,
                     encoder_hidden_states=prompt_embeds,
                     encoder_hidden_states_t5=prompt_embeds_2,
-                    image_rotary_emb=image_rotary_emb,
                     control_latents=control_latents,
                     return_dict=False,
                 )[0]
