42 changes: 38 additions & 4 deletions generate_video_df.py
@@ -6,11 +6,12 @@

import imageio
import torch
from diffusers.utils import load_image, load_video

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop

if __name__ == "__main__":

@@ -20,6 +21,7 @@
parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
parser.add_argument("--num_frames", type=int, default=97)
parser.add_argument("--image", type=str, default=None)
parser.add_argument("--video", type=str, default=None)
parser.add_argument("--ar_step", type=int, default=0)
parser.add_argument("--causal_attention", action="store_true")
parser.add_argument("--causal_block_size", type=int, default=1)
@@ -35,6 +37,7 @@
parser.add_argument("--seed", type=int, default=None)
parser.add_argument(
"--prompt",
nargs="+",
type=str,
default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
)
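With nargs="+", args.prompt becomes a list of prompts, one per long-video window, consumed in order by the pipeline. A minimal parsing sketch; the two-prompt invocation and the short default below are made up:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--prompt", nargs="+", type=str, default=["fallback prompt"])
args = p.parse_args(["--prompt", "first window", "second window"])
assert args.prompt == ["first window", "second window"]  # always a list when the flag is given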
@@ -82,7 +85,32 @@

    guidance_scale = args.guidance_scale
    shift = args.shift
    if args.image:
        args.image = load_image(args.image)
        image_width, image_height = args.image.size
        if image_height > image_width:
            # Portrait image: swap the target dimensions before cropping.
            height, width = width, height
        args.image = resizecrop(args.image, height, width)
    image = args.image.convert("RGB") if args.image else None

    video = []
    pre_video_length = 17
    if args.overlap_history is not None:
        pre_video_length = args.overlap_history
    if args.video:
        args.video = load_video(args.video)
        arg_width = width
        arg_height = height
        for img in args.video:
            image_width, image_height = img.size
            if image_height > image_width:
                # Portrait frame: swap the target dimensions before cropping.
                height, width = arg_width, arg_height
            else:
                # Landscape frame: restore the original target dimensions.
                height, width = arg_height, arg_width
            img = resizecrop(img, height, width)
            video.append(img.convert("RGB").resize((width, height)))
        video = video[-pre_video_length:]
    else:
        video = None

negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    save_dir = os.path.join("result", args.outdir)
@@ -141,11 +169,16 @@
print(f"prompt:{prompt_input}")
print(f"guidance_scale:{guidance_scale}")

output_path = ""
current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
video_out_file = f"{args.prompt[0][:100].replace('/','')}_{args.seed}_{current_time}.mp4"

with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
video_frames = pipe(
prompt=prompt_input,
negative_prompt=negative_prompt,
image=image,
video=video,
height=height,
width=width,
num_frames=num_frames,
@@ -159,10 +192,11 @@
            ar_step=args.ar_step,
            causal_block_size=args.causal_block_size,
            fps=fps,
            local_rank=local_rank,
            save_dir=save_dir,
            video_out_file=video_out_file,
        )[0]

    if local_rank == 0:
        output_path = os.path.join(save_dir, video_out_file)
        imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
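The video-conditioning path above reduces to: load the clip, normalize every frame to the target resolution, and keep only the last pre_video_length frames as the generation prefix. A minimal sketch of that tail selection; the frame sizes and counts here are made up:

from PIL import Image

frames = [Image.new("RGB", (960, 544)) for _ in range(40)]  # stand-in for load_video(args.video)
pre_video_length = 17  # default prefix length; --overlap_history overrides it
prefix = [f.convert("RGB").resize((960, 544)) for f in frames][-pre_video_length:]
assert len(prefix) == pre_video_length  # only the clip's tail conditions the first window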
62 changes: 55 additions & 7 deletions skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
@@ -5,6 +5,7 @@
from typing import Tuple
from typing import Union

import imageio
import numpy as np
import torch
from diffusers.image_processor import PipelineImageInput
@@ -95,6 +96,26 @@ def encode_image(
        predix_video_latent_length = prefix_video[0].shape[1]
        return prefix_video, predix_video_latent_length

    def encode_video(
        self, video: List[PipelineImageInput], height: int, width: int, num_frames: int
    ) -> Tuple[List[torch.Tensor], int]:
        # Encode the conditioning frames into VAE latents, mirroring encode_image.
        prefix_video = np.array(video).transpose(3, 0, 1, 2)  # (f, h, w, c) -> (c, f, h, w)
        prefix_video = torch.tensor(prefix_video)
        if prefix_video.dtype == torch.uint8:
            prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0  # [0, 255] -> [-1, 1]
        prefix_video = prefix_video.to(self.device)
        prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]]  # [(c, f, h, w)]
        print(prefix_video[0].shape)
        causal_block_size = self.transformer.num_frame_per_block
        if prefix_video[0].shape[1] % causal_block_size != 0:
            truncate_len = prefix_video[0].shape[1] % causal_block_size
            print("the length of prefix video is truncated for the causal block size alignment.")
            prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
        predix_video_latent_length = prefix_video[0].shape[1]
        return prefix_video, predix_video_latent_length

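    # A quick numeric check of the uint8 branch in encode_video: 255.0 / 2.0 == 127.5,
    # so frame / 127.5 - 1.0 maps pixel value 0 -> -1.0, 127.5 -> 0.0, and 255 -> 1.0,
    # i.e. into the [-1, 1] range the VAE encoder expects.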
    def prepare_latents(
        self,
        shape: Tuple[int],
@@ -183,9 +204,10 @@ def generate_timestep_matrix(
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = "",
        image: PipelineImageInput = None,
        video: List[PipelineImageInput] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 97,
@@ -199,6 +221,9 @@ def __call__(
        ar_step: int = 5,
        causal_block_size: int = None,
        fps: int = 24,
        local_rank: int = 0,
        save_dir: str = "",
        video_out_file: str = "",
    ):
        latent_height = height // 8
        latent_width = width // 8
@@ -211,9 +236,18 @@
        predix_video_latent_length = 0
        if image:
            prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
        elif video:
            prefix_video, predix_video_latent_length = self.encode_video(video, height, width, num_frames)

        self.text_encoder.to(self.device)
        # Encode every prompt up front; the first embedding seeds the first window.
        prompt_embeds_list = []
        if isinstance(prompt, list):
            for prompt_iter in prompt:
                prompt_embeds_list.append(self.text_encoder.encode(prompt_iter).to(self.transformer.dtype))
        else:
            prompt_embeds_list.append(self.text_encoder.encode(prompt).to(self.transformer.dtype))
        prompt_embeds = prompt_embeds_list[0]

        if self.do_classifier_free_guidance:
            negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
        if self.offload:
@@ -316,22 +350,25 @@ def __call__(
            n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
            print(f"n_iter:{n_iter}")
            output_video = None
            for i_n_iter in range(n_iter):
                if isinstance(prompt, list):
                    if len(prompt) > i_n_iter:
                        prompt_embeds = prompt_embeds_list[i_n_iter]
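                # If fewer prompts than windows are supplied, prompt_embeds keeps its
                # last value, so the final prompt drives the remaining windows.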
                if output_video is not None:  # i_n_iter != 0
                    prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
                    prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]]  # [(c, f, h, w)]
                    if prefix_video[0].shape[1] % causal_block_size != 0:
                        truncate_len = prefix_video[0].shape[1] % causal_block_size
                        print("the length of prefix video is truncated for the causal block size alignment.")
                        prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
                    predix_video_latent_length = prefix_video[0].shape[1]
                    finished_frame_num = i_n_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames
                    left_frame_num = latent_length - finished_frame_num
                    base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
                    if ar_step > 0 and self.transformer.enable_teacache:
                        num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
                        self.transformer.num_steps = num_steps
                else:  # i_n_iter == 0
                    base_num_frames_iter = base_num_frames
                latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
                latents = self.prepare_latents(
@@ -413,7 +450,18 @@ def __call__(
                    self.transformer.cpu()
                    torch.cuda.empty_cache()
                x0 = latents[0].unsqueeze(0)
                mid_output_video = self.vae.decode(x0)
                videos = [mid_output_video[0]]
                if local_rank == 0:
                    # Dump this window's frames so long runs can be inspected early.
                    mid_output_video = (mid_output_video / 2 + 0.5).clamp(0, 1)
                    mid_output_video = [item.permute(1, 2, 3, 0) * 255 for item in mid_output_video]
                    mid_output_video = [item.cpu().numpy().astype(np.uint8) for item in mid_output_video]

                    mid_video_out_file = f"mid_{i_n_iter}_{video_out_file}"
                    mid_output_path = os.path.join(save_dir, mid_video_out_file)
                    imageio.mimwrite(mid_output_path, mid_output_video[0], fps=fps, quality=8, output_params=["-loglevel", "error"])

                if output_video is None:
                    output_video = videos[0].clamp(-1, 1).cpu()  # c, f, h, w
                else:
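For reference, a self-contained sketch of the mid-video frame conversion used above; the decoded tensor here is a synthetic stand-in for the VAE decoder output, and the output file name is made up:

import imageio
import numpy as np
import torch

decoded = torch.rand(1, 3, 17, 64, 64) * 2 - 1  # (b, c, f, h, w) in [-1, 1]
frames = (decoded / 2 + 0.5).clamp(0, 1)  # -> [0, 1]
frames = frames[0].permute(1, 2, 3, 0) * 255  # (c, f, h, w) -> (f, h, w, c)
frames = frames.cpu().numpy().astype(np.uint8)  # imageio expects uint8 frames
imageio.mimwrite("mid_preview.mp4", frames, fps=24, quality=8)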