From 08f8ca49ff76395aadddd4832e42dea4189b80ab Mon Sep 17 00:00:00 2001 From: chaojie Date: Tue, 22 Apr 2025 03:44:06 +0800 Subject: [PATCH 1/6] add prompt travel python3 generate_video_df.py --model_id ${model_id} --resolution 540P --ar_step 0 --base_num_frames 97 --num_frames 177 --overlap_history 17 --addnoise_condition 20 --offload --prompt 'A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.' 'A woman flies into space' --- generate_video_df.py | 3 ++- .../pipelines/diffusion_forcing_pipeline.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/generate_video_df.py b/generate_video_df.py index b1bbd9c..9d192da 100644 --- a/generate_video_df.py +++ b/generate_video_df.py @@ -35,6 +35,7 @@ parser.add_argument("--seed", type=int, default=-1) parser.add_argument( "--prompt", + nargs="+", type=str, default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.", ) @@ -143,6 +144,6 @@ if local_rank == 0: current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) - video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4" + video_out_file = f"{args.prompt[0][:100].replace('/','')}_{args.seed}_{current_time}.mp4" output_path = os.path.join(save_dir, video_out_file) imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"]) diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py index c6f39ec..3d4d6ce 100644 --- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py @@ -183,7 +183,7 @@ def generate_timestep_matrix( @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt, negative_prompt: Union[str, List[str]] = "", image: PipelineImageInput = None, height: int = 480, @@ -213,7 +213,14 @@ def __call__( prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) self.text_encoder.to(self.device) - prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype) + prompt_embeds_list = [] + if type(prompt) is list: + for prompt_iter in prompt: + prompt_embeds_list.append(self.text_encoder.encode(prompt_iter).to(self.transformer.dtype)) + else: + prompt_embeds_list.append(self.text_encoder.encode(prompt).to(self.transformer.dtype)) + prompt_embeds = prompt_embeds_list[0] + if self.do_classifier_free_guidance: negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype) if self.offload: @@ -317,6 +324,9 @@ def __call__( print(f"n_iter:{n_iter}") output_video = None for i in range(n_iter): + if type(prompt) is list: + if len(prompt) > i: + prompt_embeds = prompt_embeds_list[i] if output_video is not None: # i !=0 prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] From 1382b4cc7c753b24131071a5013e6d31dd1fb883 Mon Sep 17 00:00:00 2001 From: "fles@qq.com" Date: Tue, 22 Apr 2025 11:01:29 +0800 Subject: [PATCH 2/6] add --video reference --- generate_video_df.py | 6 ++++- .../pipelines/diffusion_forcing_pipeline.py | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/generate_video_df.py b/generate_video_df.py index 9d192da..19a6b65 100644 --- a/generate_video_df.py +++ b/generate_video_df.py @@ -6,7 +6,7 @@ import imageio import torch -from diffusers.utils import load_image +from diffusers.utils import load_image, load_video from skyreels_v2_infer import DiffusionForcingPipeline from skyreels_v2_infer.modules import download_model @@ -20,6 +20,7 @@ parser.add_argument("--resolution", type=str, choices=["540P", "720P"]) parser.add_argument("--num_frames", type=int, default=97) parser.add_argument("--image", type=str, default=None) + parser.add_argument("--video", type=str, default=None) parser.add_argument("--ar_step", type=int, default=0) parser.add_argument("--causal_attention", action="store_true") parser.add_argument("--causal_block_size", type=int, default=1) @@ -74,6 +75,8 @@ guidance_scale = args.guidance_scale shift = args.shift image = load_image(args.image).convert("RGB") if args.image else None + video = load_video(args.video) if args.video else None + video = video[-17:] negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" save_dir = os.path.join("result", args.outdir) @@ -127,6 +130,7 @@ prompt=prompt_input, negative_prompt=negative_prompt, image=image, + video=video, height=height, width=width, num_frames=num_frames, diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py index 3d4d6ce..e48be51 100644 --- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py @@ -95,6 +95,29 @@ def encode_image( predix_video_latent_length = prefix_video[0].shape[1] return prefix_video, predix_video_latent_length + def encode_video( + self, video: List[PipelineImageInput], height: int, width: int, num_frames: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # prefix_video + prefix_video = [] + for image in video: + prefix_video.append(image.convert("RGB").resize((width, height))) + prefix_video = np.array(prefix_video).transpose(3, 0, 1, 2) + prefix_video = torch.tensor(prefix_video) # .to(image_embeds.dtype).unsqueeze(1) + if prefix_video.dtype == torch.uint8: + prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 + prefix_video = prefix_video.to(self.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] + print(prefix_video[0].shape) + causal_block_size = self.transformer.num_frame_per_block + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + return prefix_video, predix_video_latent_length + def prepare_latents( self, shape: Tuple[int], @@ -186,6 +209,7 @@ def __call__( prompt, negative_prompt: Union[str, List[str]] = "", image: PipelineImageInput = None, + video: List[PipelineImageInput] = None, height: int = 480, width: int = 832, num_frames: int = 97, @@ -211,6 +235,8 @@ def __call__( predix_video_latent_length = 0 if image: prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) + elif video: + prefix_video, predix_video_latent_length = self.encode_video(video, height, width, num_frames) self.text_encoder.to(self.device) prompt_embeds_list = [] From 38214637df960d06d986bf319cd75af5a99f8ab8 Mon Sep 17 00:00:00 2001 From: "fles@qq.com" Date: Tue, 22 Apr 2025 14:23:35 +0800 Subject: [PATCH 3/6] fix image size compatibility when Diffusion Forcing --- generate_video_df.py | 27 ++++++++++++++++--- .../pipelines/diffusion_forcing_pipeline.py | 5 +--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/generate_video_df.py b/generate_video_df.py index 19a6b65..5cbf509 100644 --- a/generate_video_df.py +++ b/generate_video_df.py @@ -11,6 +11,7 @@ from skyreels_v2_infer import DiffusionForcingPipeline from skyreels_v2_infer.modules import download_model from skyreels_v2_infer.pipelines import PromptEnhancer +from skyreels_v2_infer.pipelines import resizecrop if __name__ == "__main__": @@ -74,9 +75,29 @@ guidance_scale = args.guidance_scale shift = args.shift - image = load_image(args.image).convert("RGB") if args.image else None - video = load_video(args.video) if args.video else None - video = video[-17:] + if args.image: + args.image = load_image(args.image) + image_width, image_height = args.image.size + if image_height > image_width: + height, width = width, height + args.image = resizecrop(args.image, height, width) + image = args.image.convert("RGB") if args.image else None + + video = [] + if args.video: + args.video = load_video(args.video) + arg_width = width + arg_height = height + for img in args.video: + image_width, image_height = img.size + if image_height > image_width: + height, width = arg_width, arg_height + img = resizecrop(img, height, width) + video.append(img.convert("RGB").resize((width, height))) + video = video[-17:] + else: + video = None + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" save_dir = os.path.join("result", args.outdir) diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py index e48be51..8cff0d5 100644 --- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py @@ -100,10 +100,7 @@ def encode_video( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # prefix_video - prefix_video = [] - for image in video: - prefix_video.append(image.convert("RGB").resize((width, height))) - prefix_video = np.array(prefix_video).transpose(3, 0, 1, 2) + prefix_video = np.array(video).transpose(3, 0, 1, 2) prefix_video = torch.tensor(prefix_video) # .to(image_embeds.dtype).unsqueeze(1) if prefix_video.dtype == torch.uint8: prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 From 2e1717147b2ade9469c0630975fed8b0acf8de9c Mon Sep 17 00:00:00 2001 From: "fles@qq.com" Date: Wed, 23 Apr 2025 16:44:28 +0800 Subject: [PATCH 4/6] add mid video output --- generate_video_df.py | 9 ++++++-- .../pipelines/diffusion_forcing_pipeline.py | 23 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/generate_video_df.py b/generate_video_df.py index b76d509..4b3e282 100644 --- a/generate_video_df.py +++ b/generate_video_df.py @@ -146,6 +146,10 @@ print(f"prompt:{prompt_input}") print(f"guidance_scale:{guidance_scale}") + output_path = "" + current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) + video_out_file = f"{args.prompt[0][:100].replace('/','')}_{args.seed}_{current_time}.mp4" + with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad(): video_frames = pipe( prompt=prompt_input, @@ -165,10 +169,11 @@ ar_step=args.ar_step, causal_block_size=args.causal_block_size, fps=fps, + local_rank=local_rank, + save_dir=save_dir, + video_out_file=video_out_file, )[0] if local_rank == 0: - current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) - video_out_file = f"{args.prompt[0][:100].replace('/','')}_{args.seed}_{current_time}.mp4" output_path = os.path.join(save_dir, video_out_file) imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"]) diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py index 8cff0d5..34e6274 100644 --- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py @@ -5,6 +5,7 @@ from typing import Tuple from typing import Union +import imageio import numpy as np import torch from diffusers.image_processor import PipelineImageInput @@ -220,6 +221,9 @@ def __call__( ar_step: int = 5, causal_block_size: int = None, fps: int = 24, + local_rank: int = 0, + save_dir: str = "", + video_out_file: str = "", ): latent_height = height // 8 latent_width = width // 8 @@ -346,10 +350,10 @@ def __call__( n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 print(f"n_iter:{n_iter}") output_video = None - for i in range(n_iter): + for i_n_iter in range(n_iter): if type(prompt) is list: - if len(prompt) > i: - prompt_embeds = prompt_embeds_list[i] + if len(prompt) > i_n_iter: + prompt_embeds = prompt_embeds_list[i_n_iter] if output_video is not None: # i !=0 prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] @@ -443,7 +447,18 @@ def __call__( self.transformer.cpu() torch.cuda.empty_cache() x0 = latents[0].unsqueeze(0) - videos = [self.vae.decode(x0)[0]] + mid_output_video = self.vae.decode(x0) + videos = [mid_output_video[0]] + if local_rank == 0: + mid_output_video = (mid_output_video / 2 + 0.5).clamp(0, 1) + mid_output_video = [video for video in mid_output_video] + mid_output_video = [video.permute(1, 2, 3, 0) * 255 for video in mid_output_video] + mid_output_video = [video.cpu().numpy().astype(np.uint8) for video in mid_output_video] + + mid_video_out_file = f"mid_{i_n_iter}_{video_out_file}" + mid_output_path = os.path.join(save_dir, mid_video_out_file) + imageio.mimwrite(mid_output_path, mid_output_video[0], fps=fps, quality=8, output_params=["-loglevel", "error"]) + if output_video is None: output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w else: From ae4796e3c26b7e86594d3261717ba4dc38a6614d Mon Sep 17 00:00:00 2001 From: "fles@qq.com" Date: Wed, 23 Apr 2025 19:11:05 +0800 Subject: [PATCH 5/6] fix bug --- skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py index 34e6274..6fd429b 100644 --- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py @@ -354,7 +354,7 @@ def __call__( if type(prompt) is list: if len(prompt) > i_n_iter: prompt_embeds = prompt_embeds_list[i_n_iter] - if output_video is not None: # i !=0 + if output_video is not None: # i_n_iter !=0 prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] if prefix_video[0].shape[1] % causal_block_size != 0: @@ -362,10 +362,10 @@ def __call__( print("the length of prefix video is truncated for the casual block size alignment.") prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] predix_video_latent_length = prefix_video[0].shape[1] - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames + finished_frame_num = i_n_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = latent_length - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: # i == 0 + else: # i_n_iter == 0 base_num_frames_iter = base_num_frames latent_shape = [16, base_num_frames_iter, latent_height, latent_width] latents = self.prepare_latents( From 61b31fd4d651d142f256ff503b7579296f2546df Mon Sep 17 00:00:00 2001 From: "fles@qq.com" Date: Thu, 24 Apr 2025 14:17:11 +0800 Subject: [PATCH 6/6] remove hard code reference video overlap_history --- generate_video_df.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/generate_video_df.py b/generate_video_df.py index 4b3e282..e04c9ee 100644 --- a/generate_video_df.py +++ b/generate_video_df.py @@ -84,6 +84,9 @@ image = args.image.convert("RGB") if args.image else None video = [] + pre_video_length = 17 + if args.overlap_history is not None: + pre_video_length = args.overlap_history if args.video: args.video = load_video(args.video) arg_width = width @@ -94,7 +97,7 @@ height, width = arg_width, arg_height img = resizecrop(img, height, width) video.append(img.convert("RGB").resize((width, height))) - video = video[-17:] + video = video[-pre_video_length:] else: video = None