 import numpy as np
 import torch
 import torch.nn.functional as F
-from einops import rearrange
 from PIL import Image
 from transformers import (
     BertModel,
         >>> from diffusers.utils import export_to_video, load_video

         >>> pipe = EasyAnimateControlPipeline.from_pretrained(
-        ...     "alibaba-pai/EasyAnimateV5.1-12b-zh-Control", torch_dtype=torch.bfloat16
+        ...     "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
         ... )
         >>> pipe.to("cuda")

         >>> video = pipe(
         ...     prompt,
         ...     num_frames=num_frames,
-        ...     negative_prompt="bad detailed",
+        ...     negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
         ...     height=sample_size[0],
         ...     width=sample_size[1],
         ...     control_video=input_video,
         ```
 """

+def preprocess_image(image, sample_size):
+    """
+    Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
+    """
+    if isinstance(image, torch.Tensor):
+        # If input is a tensor, assume it's in CHW format and resize using interpolation
+        image = torch.nn.functional.interpolate(
+            image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
+        ).squeeze(0)
+    elif isinstance(image, Image.Image):
+        # If input is a PIL image, resize and convert to numpy array
+        image = image.resize((sample_size[1], sample_size[0]))
+        image = np.array(image)
+    elif isinstance(image, np.ndarray):
+        # If input is a numpy array, resize using PIL
+        image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
+        image = np.array(image)
+    else:
+        raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")

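+    # NOTE: tensor inputs are returned without rescaling; only PIL / ndarray inputs are
+    # converted to float and scaled to [0, 1] below.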
-def get_video_to_video_latent(
-    input_video_path, num_frames, sample_size, fps=None, validation_video_mask=None, ref_image=None
-):
-    if input_video_path is not None:
-        if isinstance(input_video_path, str):
-            import cv2
-
-            cap = cv2.VideoCapture(input_video_path)
-            input_video = []
-
-            original_fps = cap.get(cv2.CAP_PROP_FPS)
-            frame_skip = 1 if fps is None else int(original_fps // fps)
-
-            frame_count = 0
+    # Convert to tensor if not already
+    if not isinstance(image, torch.Tensor):
+        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # HWC -> CHW, normalize to [0, 1]

-            while True:
-                ret, frame = cap.read()
-                if not ret:
-                    break
+    return image
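+
+
+# Usage sketch (illustrative values, not part of this commit):
+#   frames = load_video("input.mp4")  # list of PIL images via diffusers.utils.load_video
+#   frame = preprocess_image(frames[0], sample_size=(384, 672))
+#   video, mask, ref = get_video_to_video_latent(frames, num_frames=49, sample_size=(384, 672))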

-                if frame_count % frame_skip == 0:
-                    frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
-                    input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

-                frame_count += 1
+def get_video_to_video_latent(
+    input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None
+):
+    if input_video is not None:
+        # Convert each frame in the list to tensor
+        input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]

-            cap.release()
-        else:
-            input_video = input_video_path
+        # Stack all frames into a single tensor (F, C, H, W)
+        input_video = torch.stack(input_video)[:num_frames]

-        input_video = torch.from_numpy(np.array(input_video))[:num_frames]
-        input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
+        # Add batch dimension: (F, C, H, W) -> (B, C, F, H, W)
+        input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)

         if validation_video_mask is not None:
-            validation_video_mask = (
-                Image.open(validation_video_mask).convert("L").resize((sample_size[1], sample_size[0]))
-            )
-            input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
-
-            input_video_mask = (
-                torch.from_numpy(np.array(input_video_mask))
-                .unsqueeze(0)
-                .unsqueeze(-1)
-                .permute([3, 0, 1, 2])
-                .unsqueeze(0)
-            )
+            # Handle mask input; keep a single channel so the thresholded mask is (H, W)
+            validation_video_mask = preprocess_image(validation_video_mask, sample_size=sample_size)[0]
+            input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255.0)
+
+            # Adjust mask dimensions to match the video: (H, W) -> (1, 1, 1, H, W); frames are tiled below
+            input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
             input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
             input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
         else:
@@ -149,14 +148,12 @@ def get_video_to_video_latent(
         input_video, input_video_mask = None, None

     if ref_image is not None:
-        if isinstance(ref_image, str):
-            ref_image = Image.open(ref_image).convert("RGB")
-            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
-            ref_image = torch.from_numpy(np.array(ref_image))
-            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
-        else:
-            ref_image = torch.from_numpy(np.array(ref_image))
-            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+        # Convert reference image to tensor
+        ref_image = preprocess_image(ref_image, sample_size=sample_size)
+        ref_image = ref_image.unsqueeze(1).unsqueeze(0)  # (C, H, W) -> (B, C, F, H, W) with F = 1
+    else:
+        ref_image = None
+
     return input_video, input_video_mask, ref_image


@@ -1025,12 +1022,12 @@ def __call__(
                 torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
             ).to(device, dtype)
         elif control_video is not None:
-            num_frames = control_video.shape[2]
+            batch_size, channels, num_frames, height_video, width_video = control_video.shape
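+            # Fold frames into the batch axis (equivalent to the removed einops rearrange "b c f h w -> (b f) c h w")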
             control_video = self.image_processor.preprocess(
-                rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width
+                control_video.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), height=height, width=width
             )
             control_video = control_video.to(dtype=torch.float32)
-            control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=num_frames)
+            control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
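+            # Restore the (b, c, f, h, w) layout (equivalent to the removed rearrange "(b f) c h w -> b c f h w")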
             control_video_latents = self.prepare_control_latents(
                 None,
                 control_video,
@@ -1052,12 +1049,12 @@ def __call__(
             ).to(device, dtype)

         if ref_image is not None:
-            num_frames = ref_image.shape[2]
+            batch_size, channels, num_frames, height_video, width_video = ref_image.shape
             ref_image = self.image_processor.preprocess(
-                rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width
+                ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), height=height, width=width
             )
             ref_image = ref_image.to(dtype=torch.float32)
-            ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=num_frames)
+            ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)

             ref_image_latentes = self.prepare_control_latents(
                 None,
@@ -1092,30 +1089,6 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7 create image_rotary_emb, style embedding & time ids
-        grid_height = height // 8 // self.transformer.config.patch_size
-        grid_width = width // 8 // self.transformer.config.patch_size
-        if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
-            base_size_width = 720 // 8 // self.transformer.config.patch_size
-            base_size_height = 480 // 8 // self.transformer.config.patch_size
-
-            grid_crops_coords = get_resize_crop_region_for_grid(
-                (grid_height, grid_width), base_size_width, base_size_height
-            )
-            image_rotary_emb = get_3d_rotary_pos_embed(
-                self.transformer.config.attention_head_dim,
-                grid_crops_coords,
-                grid_size=(grid_height, grid_width),
-                temporal_size=latents.size(2),
-                use_real=True,
-            )
-        else:
-            base_size = 512 // 8 // self.transformer.config.patch_size
-            grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size, base_size)
-            image_rotary_emb = get_2d_rotary_pos_embed(
-                self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
-            )
-
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
@@ -1130,7 +1103,7 @@ def __call__(
             prompt_embeds_2 = prompt_embeds_2.to(device=device)
             prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)

-        # 8. Denoising loop
+        # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1153,7 +1126,6 @@ def __call__(
                     t_expand,
                     encoder_hidden_states=prompt_embeds,
                     encoder_hidden_states_t5=prompt_embeds_2,
-                    image_rotary_emb=image_rotary_emb,
                     control_latents=control_latents,
                     return_dict=False,
                 )[0]