diff --git a/generate_video_df.py b/generate_video_df.py
index 7b73cd2..675ea7d 100644
--- a/generate_video_df.py
+++ b/generate_video_df.py
@@ -6,11 +6,12 @@
 
 import imageio
 import torch
-from diffusers.utils import load_image
+from diffusers.utils import load_image, load_video
 
 from skyreels_v2_infer import DiffusionForcingPipeline
 from skyreels_v2_infer.modules import download_model
 from skyreels_v2_infer.pipelines import PromptEnhancer
+from skyreels_v2_infer.pipelines import resizecrop
 
 
 if __name__ == "__main__":
@@ -20,6 +21,7 @@
     parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
     parser.add_argument("--num_frames", type=int, default=97)
     parser.add_argument("--image", type=str, default=None)
+    parser.add_argument("--video", type=str, default=None)
     parser.add_argument("--ar_step", type=int, default=0)
     parser.add_argument("--causal_attention", action="store_true")
     parser.add_argument("--causal_block_size", type=int, default=1)
@@ -35,6 +37,7 @@
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument(
         "--prompt",
+        nargs="+",
         type=str,
         default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
     )
@@ -82,7 +85,32 @@
 
     guidance_scale = args.guidance_scale
     shift = args.shift
-    image = load_image(args.image).convert("RGB") if args.image else None
+    if args.image:
+        args.image = load_image(args.image)
+        image_width, image_height = args.image.size
+        if image_height > image_width:
+            height, width = width, height
+        args.image = resizecrop(args.image, height, width)
+    image = args.image.convert("RGB") if args.image else None
+
+    video = []
+    pre_video_length = 17
+    if args.overlap_history is not None:
+        pre_video_length = args.overlap_history
+    if args.video:
+        args.video = load_video(args.video)
+        arg_width = width
+        arg_height = height
+        for img in args.video:
+            image_width, image_height = img.size
+            if image_height > image_width:
+                height, width = arg_width, arg_height
+            img = resizecrop(img, height, width)
+            video.append(img.convert("RGB").resize((width, height)))
+        video = video[-pre_video_length:]
+    else:
+        video = None
+
     negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
 
     save_dir = os.path.join("result", args.outdir)
@@ -141,11 +169,16 @@
     print(f"prompt:{prompt_input}")
     print(f"guidance_scale:{guidance_scale}")
 
+    output_path = ""
+    current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
+    video_out_file = f"{args.prompt[0][:100].replace('/','')}_{args.seed}_{current_time}.mp4"
+
     with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
         video_frames = pipe(
             prompt=prompt_input,
             negative_prompt=negative_prompt,
             image=image,
+            video=video,
             height=height,
             width=width,
             num_frames=num_frames,
@@ -159,10 +192,11 @@
                 ar_step=args.ar_step,
                 causal_block_size=args.causal_block_size,
                 fps=fps,
+                local_rank=local_rank,
+                save_dir=save_dir,
+                video_out_file=video_out_file,
             )[0]
 
     if local_rank == 0:
-        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
-        video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
         output_path = os.path.join(save_dir, video_out_file)
         imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
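
For reference, a condensed sketch of the --video handling added above. The helper name load_prefix_frames is hypothetical (not part of the repo); it assumes diffusers.utils.load_video returns a list of PIL images, that resizecrop takes (image, height, width) as used in the diff, and it checks the portrait/landscape orientation once from the first frame instead of per frame, which is equivalent for a clip with constant frame size.

    from diffusers.utils import load_video
    from skyreels_v2_infer.pipelines import resizecrop

    def load_prefix_frames(path, height, width, overlap_history=None):
        # Keep only the trailing frames that seed the first generated chunk;
        # the script above defaults to 17 frames when --overlap_history is unset.
        keep = overlap_history if overlap_history is not None else 17
        frames = load_video(path)
        # Portrait clip: swap the target resolution once (PIL size is (width, height)).
        if frames and frames[0].size[1] > frames[0].size[0]:
            height, width = width, height
        processed = []
        for img in frames:
            img = resizecrop(img, height, width)
            processed.append(img.convert("RGB").resize((width, height)))
        return processed[-keep:]
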
diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
index b27f1d8..890ea5a 100644
--- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
+++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
@@ -5,6 +5,7 @@
 from typing import Tuple
 from typing import Union
 
+import imageio
 import numpy as np
 import torch
 from diffusers.image_processor import PipelineImageInput
@@ -95,6 +96,26 @@ def encode_image(
         predix_video_latent_length = prefix_video[0].shape[1]
         return prefix_video, predix_video_latent_length
 
+    def encode_video(
+        self, video: List[PipelineImageInput], height: int, width: int, num_frames: int
+    ) -> Tuple[List[torch.Tensor], int]:
+
+        # prefix_video
+        prefix_video = np.array(video).transpose(3, 0, 1, 2)
+        prefix_video = torch.tensor(prefix_video)  # .to(image_embeds.dtype).unsqueeze(1)
+        if prefix_video.dtype == torch.uint8:
+            prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
+        prefix_video = prefix_video.to(self.device)
+        prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]]  # [(c, f, h, w)]
+        print(prefix_video[0].shape)
+        causal_block_size = self.transformer.num_frame_per_block
+        if prefix_video[0].shape[1] % causal_block_size != 0:
+            truncate_len = prefix_video[0].shape[1] % causal_block_size
+            print("the length of prefix video is truncated for the causal block size alignment.")
+            prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
+        predix_video_latent_length = prefix_video[0].shape[1]
+        return prefix_video, predix_video_latent_length
+
     def prepare_latents(
         self,
         shape: Tuple[int],
@@ -183,9 +204,10 @@ def generate_timestep_matrix(
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt,
         negative_prompt: Union[str, List[str]] = "",
         image: PipelineImageInput = None,
+        video: List[PipelineImageInput] = None,
         height: int = 480,
         width: int = 832,
         num_frames: int = 97,
@@ -199,6 +221,9 @@ def __call__(
         ar_step: int = 5,
         causal_block_size: int = None,
         fps: int = 24,
+        local_rank: int = 0,
+        save_dir: str = "",
+        video_out_file: str = "",
     ):
         latent_height = height // 8
         latent_width = width // 8
@@ -211,9 +236,18 @@ def __call__(
         predix_video_latent_length = 0
         if image:
             prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
+        elif video:
+            prefix_video, predix_video_latent_length = self.encode_video(video, height, width, num_frames)
 
         self.text_encoder.to(self.device)
-        prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
+        prompt_embeds_list = []
+        if type(prompt) is list:
+            for prompt_iter in prompt:
+                prompt_embeds_list.append(self.text_encoder.encode(prompt_iter).to(self.transformer.dtype))
+        else:
+            prompt_embeds_list.append(self.text_encoder.encode(prompt).to(self.transformer.dtype))
+        prompt_embeds = prompt_embeds_list[0]
+
         if self.do_classifier_free_guidance:
             negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
         if self.offload:
@@ -316,8 +350,11 @@ def __call__(
             n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
             print(f"n_iter:{n_iter}")
             output_video = None
-            for i in range(n_iter):
-                if output_video is not None:  # i !=0
+            for i_n_iter in range(n_iter):
+                if type(prompt) is list:
+                    if len(prompt) > i_n_iter:
+                        prompt_embeds = prompt_embeds_list[i_n_iter]
+                if output_video is not None:  # i_n_iter !=0
                     prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
                     prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]]  # [(c, f, h, w)]
                     if prefix_video[0].shape[1] % causal_block_size != 0:
@@ -325,13 +362,13 @@
                         print("the length of prefix video is truncated for the casual block size alignment.")
                         prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
                     predix_video_latent_length = prefix_video[0].shape[1]
-                    finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
+                    finished_frame_num = i_n_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames
                     left_frame_num = latent_length - finished_frame_num
                     base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
                     if ar_step > 0 and self.transformer.enable_teacache:
                         num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
                         self.transformer.num_steps = num_steps
-                else:  # i == 0
+                else:  # i_n_iter == 0
                     base_num_frames_iter = base_num_frames
                 latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
                 latents = self.prepare_latents(
@@ -413,7 +450,18 @@
                     self.transformer.cpu()
                     torch.cuda.empty_cache()
                 x0 = latents[0].unsqueeze(0)
-                videos = [self.vae.decode(x0)[0]]
+                mid_output_video = self.vae.decode(x0)
+                videos = [mid_output_video[0]]
+                if local_rank == 0:
+                    mid_output_video = (mid_output_video / 2 + 0.5).clamp(0, 1)
+                    mid_output_video = [video for video in mid_output_video]
+                    mid_output_video = [video.permute(1, 2, 3, 0) * 255 for video in mid_output_video]
+                    mid_output_video = [video.cpu().numpy().astype(np.uint8) for video in mid_output_video]
+
+                    mid_video_out_file = f"mid_{i_n_iter}_{video_out_file}"
+                    mid_output_path = os.path.join(save_dir, mid_video_out_file)
+                    imageio.mimwrite(mid_output_path, mid_output_video[0], fps=fps, quality=8, output_params=["-loglevel", "error"])
+
                 if output_video is None:
                     output_video = videos[0].clamp(-1, 1).cpu()  # c, f, h, w
                 else:
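
For clarity, a minimal tensor-free sketch of the prompt-per-chunk behavior introduced in the long-video loop above: chunk i uses prompt i while prompts remain, and the last provided prompt carries over to any remaining chunks (for a single string prompt, nothing changes, since prompt_embeds starts as prompt_embeds_list[0]). The helper name prompts_per_chunk is hypothetical and only illustrates the selection logic.

    def prompts_per_chunk(prompts, n_iter):
        # Mirrors: if len(prompt) > i_n_iter: prompt_embeds = prompt_embeds_list[i_n_iter]
        schedule, current = [], prompts[0]
        for i in range(n_iter):
            if len(prompts) > i:
                current = prompts[i]
            schedule.append(current)
        return schedule

    # Example: three prompts spread over five generation chunks.
    assert prompts_per_chunk(["p0", "p1", "p2"], 5) == ["p0", "p1", "p2", "p2", "p2"]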