diff --git a/README.md b/README.md
index dd2a88e..befa633 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,98 @@
+## Changes from pftq:
+- **Added seed synchronization code to allow random seed with multi-GPU** (https://github.com/SkyworkAI/SkyReels-V2/issues/24).
+- **Reduced 20-min+ load time on multi-GPU to ~8min** by fixing contention (all GPUs loading models at once). Indirectly also solved CPU RAM spike during multi-GPU (>200GB on 4 GPUs) (https://github.com/SkyworkAI/SkyReels-V2/issues/28).
+- **Fixed CuSolver error** that occasionally comes up in multi-GPU by presetting linear algebra library (https://github.com/SkyworkAI/SkyReels-V2/issues/37).
+- **Added batch_size** parameter to allow multiple videos to generate without reloading the model, which takes about 20 min on multi-gpu so this saves a lot of time.
+- **Added preserve_image_aspect_ratio** parameter to allow preserving original image aspect ratio.
+- Fixed DF script not resize-cropping the image (I2V script does it but DF is missing the code).
+- Exposed negative_prompt to allow that to be changed/overwritten.
+- Friendlier filenames with date, seed, cfg, steps, and other details in front.
+
+## Additional changes from chaojie's fork (https://github.com/SkyworkAI/SkyReels-V2/pull/12):
+- **Multiple prompts**, allow multiple text strings in the --prompt parameter to guide the video differently each chunk of base_num_frames.
+- **Video input** via --video parameter, allow continuing/extending from a video.
+- **Partially complete videos saved** as each chunk of base_num_frames completes. In combination with the --video parameter, this lets you effectively resume from a previous render as well as abort mid-render if the videos take a turn you don't like. Extremely useful for saving time and "watching" as the renders complete rather than committing the full time.
+
+Example prompts below. If you run into memory/vram issues, you can reduce the base_num_frames while still having the same higher number on num_frames. The point of the DF model is that now the whole video doesn't have to fit in VRAM and can be done in chunks.
+
+Multi-GPU with video input and prompt travel, batch of 10, preserving aspect ratio.
+Change --video "video.mp4" to --image "image.jpg" if you want to load a starting image instead.
+```
+model_id=Skywork/SkyReels-V2-DF-14B-540P
+gpu_count=2
+torchrun --nproc_per_node=${gpu_count} generate_video_df.py \
+ --model_id ${model_id} \
+ --resolution 540P \
+ --ar_step 0 \
+ --base_num_frames 97 \
+ --num_frames 289 \
+ --overlap_history 17 \
+ --inference_steps 50 \
+ --guidance_scale 6 \
+ --batch_size 10 \
+ --preserve_image_aspect_ratio \
+ --video "video.mp4" \
+ --prompt "The first thing he does" \
+ "The second thing he does." \
+ "The third thing he does." \
+ --negative_prompt "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" \
+ --addnoise_condition 20 \
+ --use_ret_steps \
+ --teacache_thresh 0.0 \
+ --use_usp \
+ --offload
+```
+
+Single GPU with video input and prompt travel, batch of 10, preserving aspect ratio.
+Change --video "video.mp4" to --image "image.jpg" if you want to load a starting image instead.
+```
+model_id=Skywork/SkyReels-V2-DF-14B-540P
+python3 generate_video_df.py \
+ --model_id ${model_id} \
+ --resolution 540P \
+ --ar_step 0 \
+ --base_num_frames 97 \
+ --num_frames 289 \
+ --overlap_history 17 \
+ --inference_steps 50 \
+ --guidance_scale 6 \
+ --batch_size 10 \
+ --preserve_image_aspect_ratio \
+ --video "video.mp4" \
+ --prompt "The first thing he does" \
+ "The second thing he does." \
+ "The third thing he does." \
+ --negative_prompt "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" \
+ --addnoise_condition 20 \
+ --use_ret_steps \
+ --teacache_thresh 0.0 \
+ --offload
+```
+
+Easy install instructions for those like me using Runpod for H100 and multi-gpu:
+```
+#create once on new pod
+export HF_HOME=/workspace/
+export TZ=America/Los_Angeles
+python -m venv venv
+git clone https://github.com/pftq/SkyReels-V2_Improvements
+mv SkyReels-V2_Improvements SkyReels-V2
+cd /workspace/SkyReels-V2
+source /workspace/venv/bin/activate
+pip install torch==2.5.1
+pip install --upgrade wheel setuptools
+pip install packaging
+pip install -r requirements.txt --no-build-isolation
+```
+```
+#always run at the start to use persisting drive
+export HF_HOME=/workspace/
+export TZ=America/Los_Angeles
+source /workspace/venv/bin/activate
+cd /workspace/SkyReels-V2
+```
+
+
diff --git a/generate_video.py b/generate_video.py
index 61e1394..2071c5b 100644
--- a/generate_video.py
+++ b/generate_video.py
@@ -8,6 +8,9 @@
import torch
from diffusers.utils import load_image
+from PIL import Image #20250422 pftq: Added for image resizing and cropping
+import numpy as np #20250422 pftq: Added for seed synchronization
+
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import Image2VideoPipeline
from skyreels_v2_infer.pipelines import PromptEnhancer
@@ -41,7 +44,12 @@
parser.add_argument("--use_usp", action="store_true")
parser.add_argument("--offload", action="store_true")
parser.add_argument("--fps", type=int, default=24)
- parser.add_argument("--seed", type=int, default=None)
+ parser.add_argument("--seed", type=int, default=-1)
+
+ parser.add_argument("--batch_size", type=int, default=1) # 20250422 pftq: Batch functionality to avoid reloading the model each video
+ parser.add_argument("--preserve_image_aspect_ratio", action="store_true") # 20250422 pftq: Avoid resizing
+ parser.add_argument("--negative_prompt", type=str, default="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards") # 20250422 pftq: expose negative prompt
+
parser.add_argument(
"--prompt",
type=str,
@@ -63,22 +71,9 @@
args.model_id = download_model(args.model_id)
print("model_id:", args.model_id)
- assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
- if args.seed is None:
- random.seed(time.time())
- args.seed = int(random.randrange(4294967294))
-
- if args.resolution == "540P":
- height = 544
- width = 960
- elif args.resolution == "720P":
- height = 720
- width = 1280
- else:
- raise ValueError(f"Invalid resolution: {args.resolution}")
+ #20250422 pftq: unneeded with seed synchronization code
+ #assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
- image = load_image(args.image).convert("RGB") if args.image else None
- negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
local_rank = 0
if args.use_usp:
assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
@@ -98,6 +93,52 @@
ulysses_degree=dist.get_world_size(),
)
+ if args.resolution == "540P":
+ height = 544
+ width = 960
+ elif args.resolution == "720P":
+ height = 720
+ width = 1280
+ else:
+ raise ValueError(f"Invalid resolution: {args.resolution}")
+
+ #image = load_image(args.image).convert("RGB") if args.image else None
+
+
+ #20250422 pftq: Add error handling for image loading, aspect ratio preservation
+ image = None
+ if args.image:
+ try:
+ image = load_image(args.image).convert("RGB")
+
+ # 20250422 pftq: option to preserve image aspect ratio
+ if args.preserve_image_aspect_ratio:
+ img_width, img_height = image.size
+ if img_height > img_width:
+ height, width = width, height
+ width = int(height / img_height * img_width)
+ else:
+ height = int(width / img_width * img_height)
+
+ divisibility=16
+ if width%divisibility!=0:
+ width = width - (width%divisibility)
+ if height%divisibility!=0:
+ height = height - (height%divisibility)
+
+ image = resizecrop(image, height, width)
+ else:
+ image_width, image_height = image.size
+ if image_height > image_width:
+ height, width = width, height
+ image = resizecrop(image, height, width)
+ except Exception as e:
+ raise ValueError(f"Failed to load or process image: {e}")
+
+ print(f"Rank {local_rank}: {width}x{height} | Image: "+str(image!=None))
+
+ negative_prompt = args.negative_prompt # 20250422 pftq: allow editable negative prompt
+
prompt_input = args.prompt
if args.prompt_enhancer and args.image is None:
print(f"init prompt enhancer")
@@ -108,6 +149,9 @@
gc.collect()
torch.cuda.empty_cache()
+ # 20250423 pftq: needs fixing, 20-min load times on multi-GPU caused by contention, DF already reduced down to 12 min roughly the same as single GPU.
+ print("Initializing pipe at "+time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()))
+ starttime = time.time()
if image is None:
assert "T2V" in args.model_id, f"check model_id:{args.model_id}"
print("init text2video pipeline")
@@ -120,11 +164,8 @@
pipe = Image2VideoPipeline(
model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
)
- args.image = load_image(args.image)
- image_width, image_height = args.image.size
- if image_height > image_width:
- height, width = width, height
- args.image = resizecrop(args.image, height, width)
+ totaltime = time.time()-starttime
+ print("Finished initializing pipe at "+time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())+" ("+str(int(totaltime))+" seconds)")
if args.teacache:
pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=args.inference_steps,
@@ -132,30 +173,78 @@
ckpt_dir=args.model_id)
- kwargs = {
- "prompt": prompt_input,
- "negative_prompt": negative_prompt,
- "num_frames": args.num_frames,
- "num_inference_steps": args.inference_steps,
- "guidance_scale": args.guidance_scale,
- "shift": args.shift,
- "generator": torch.Generator(device="cuda").manual_seed(args.seed),
- "height": height,
- "width": width,
- }
-
- if image is not None:
- kwargs["image"] = args.image.convert("RGB")
-
- save_dir = os.path.join("result", args.outdir)
- os.makedirs(save_dir, exist_ok=True)
-
- with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
- print(f"infer kwargs:{kwargs}")
- video_frames = pipe(**kwargs)[0]
-
- if local_rank == 0:
- current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
- video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
- output_path = os.path.join(save_dir, video_out_file)
- imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"])
+
+ #20250422 pftq: Set preferred linear algebra backend to avoid cuSOLVER issues
+ torch.backends.cuda.preferred_linalg_library("default") # or try "magma" if available
+
+ for idx in range(args.batch_size): # 20250422 pftq: implemented --batch_size
+ if local_rank == 0:
+ print(f"Generating video {idx+1} of {args.batch_size}")
+
+ #20250422 pftq: Synchronize seed across all ranks
+ if args.use_usp:
+ try:
+ #20250422 pftq: Synchronize ranks before seed broadcasting
+ dist.barrier()
+
+ #20250422 pftq: Always broadcast seed to ensure consistency
+ if local_rank == 0:
+ if args.seed == -1 or idx > 0:
+ args.seed = int(random.randrange(4294967294))
+ seed_tensor = torch.tensor(args.seed, dtype=torch.int64, device="cuda")
+ dist.broadcast(seed_tensor, src=0)
+ args.seed = seed_tensor.item()
+
+ #20250422 pftq: Synchronize ranks after seed broadcasting
+ dist.barrier()
+ except Exception as e:
+ print(f"[Rank {local_rank}] Seed broadcasting error: {e}")
+ dist.destroy_process_group()
+ raise
+
+ else:
+ #20250422 pftq: Single GPU seed initialization
+ if args.seed == -1 or idx > 0:
+ args.seed = int(random.randrange(4294967294))
+
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ kwargs = {
+ "prompt": prompt_input,
+ "negative_prompt": negative_prompt,
+ "num_frames": args.num_frames,
+ "num_inference_steps": args.inference_steps,
+ "guidance_scale": args.guidance_scale,
+ "shift": args.shift,
+ "generator": torch.Generator(device="cuda").manual_seed(args.seed),
+ "height": height,
+ "width": width,
+ }
+
+ if image is not None:
+ #kwargs["image"] = load_image(args.image).convert("RGB")
+ # 20250422 pftq: redundant reloading of the image
+ kwargs["image"] = image
+
+ save_dir = os.path.join("result", args.outdir)
+ os.makedirs(save_dir, exist_ok=True)
+
+ with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
+ print(f"infer kwargs:{kwargs}")
+ video_frames = pipe(**kwargs)[0]
+
+ if local_rank == 0:
+ current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
+ #video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
+
+ # 20250422 pftq: more useful filename
+ gpucount = ""
+ if args.use_usp and dist.get_world_size():
+ gpucount = "_"+str(dist.get_world_size())+"xGPU"
+ video_out_file = f"{current_time}_skyreels2_{args.resolution}-{args.num_frames}f_cfg{args.guidance_scale}_steps{args.inference_steps}_seed{args.seed}{gpucount}_{args.prompt[:100].replace('/','')}_{idx}.mp4"
+
+ output_path = os.path.join(save_dir, video_out_file)
+ imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"])
diff --git a/generate_video_df.py b/generate_video_df.py
index aa49782..0ef7ba8 100644
--- a/generate_video_df.py
+++ b/generate_video_df.py
@@ -3,10 +3,12 @@
import os
import random
import time
-
-import imageio
import torch
from diffusers.utils import load_image
+import imageio
+from PIL import Image #20250422 pftq: Added for image resizing and cropping
+import numpy as np #20250422 pftq: Added for seed synchronization
+from diffusers.utils import load_video # 20250425 chaojie prompt travel & video input
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
@@ -21,6 +23,7 @@
parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
parser.add_argument("--num_frames", type=int, default=97)
parser.add_argument("--image", type=str, default=None)
+ parser.add_argument("--video", type=str, default=None) # 20250425 chaojie prompt travel & video input
parser.add_argument("--ar_step", type=int, default=0)
parser.add_argument("--causal_attention", action="store_true")
parser.add_argument("--causal_block_size", type=int, default=1)
@@ -33,9 +36,15 @@
parser.add_argument("--use_usp", action="store_true")
parser.add_argument("--offload", action="store_true")
parser.add_argument("--fps", type=int, default=24)
- parser.add_argument("--seed", type=int, default=None)
+ parser.add_argument("--seed", type=int, default=-1)
+
+ parser.add_argument("--batch_size", type=int, default=1) # 20250422 pftq: Batch functionality to avoid reloading the model each video
+ parser.add_argument("--preserve_image_aspect_ratio", action="store_true") # 20250422 pftq: Avoid resizing
+ parser.add_argument("--negative_prompt", type=str, default="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards") # 20250422 pftq: expose negative prompt
+
parser.add_argument(
"--prompt",
+ nargs="+", # 20250425 chaojie prompt travel & video input
type=str,
default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
)
@@ -57,10 +66,8 @@
args.model_id = download_model(args.model_id)
print("model_id:", args.model_id)
- assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
- if args.seed is None:
- random.seed(time.time())
- args.seed = int(random.randrange(4294967294))
+ #20250422 pftq: unneeded with seed synchronization code
+ #assert (args.use_usp and args.seed != -1) or (not args.use_usp), "usp mode requires a valid seed"
if args.resolution == "540P":
height = 544
@@ -77,14 +84,15 @@
if num_frames > args.base_num_frames:
assert (
args.overlap_history is not None
- ), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'
+ ), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommended to set.'
if args.addnoise_condition > 60:
print(
- f'You have set "addnoise_condition" as {args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.'
+ f'You have set "addnoise_condition" as {args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommended to set 20.'
)
guidance_scale = args.guidance_scale
shift = args.shift
+ """
if args.image:
args.image = load_image(args.image)
image_width, image_height = args.image.size
@@ -92,8 +100,77 @@
height, width = width, height
args.image = resizecrop(args.image, height, width)
image = args.image.convert("RGB") if args.image else None
- negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+ """
+
+ # 20250425 chaojie prompt travel & video input
+ video = []
+ if args.video:
+ pre_video_length = 17
+ if args.overlap_history is not None:
+ pre_video_length = args.overlap_history
+ args.video = load_video(args.video)
+ arg_width = width
+ arg_height = height
+ for img in args.video:
+ # 20250422 pftq: option to preserve image aspect ratio
+ if args.preserve_image_aspect_ratio:
+ img_width, img_height = img.size
+ if img_height > img_width:
+ height, width = width, height
+ width = int(height / img_height * img_width)
+ else:
+ height = int(width / img_width * img_height)
+
+ divisibility=16
+ if width%divisibility!=0:
+ width = width - (width%divisibility)
+ if height%divisibility!=0:
+ height = height - (height%divisibility)
+
+ img = resizecrop(img, height, width)
+ else:
+ image_width, image_height = img.size
+ if image_height > image_width:
+ height, width = width, height
+ img = resizecrop(img, height, width)
+ video.append(img.convert("RGB").resize((width, height)))
+ video = video[-pre_video_length:]
+ else:
+ video = None
+ #20250422 pftq: Add error handling for image loading, aspect ratio preservation
+ image = None
+ if args.image and not args.video:
+ try:
+ image = load_image(args.image).convert("RGB")
+
+ # 20250422 pftq: option to preserve image aspect ratio
+ if args.preserve_image_aspect_ratio:
+ img_width, img_height = image.size
+ if img_height > img_width:
+ height, width = width, height
+ width = int(height / img_height * img_width)
+ else:
+ height = int(width / img_width * img_height)
+
+ divisibility=16
+ if width%divisibility!=0:
+ width = width - (width%divisibility)
+ if height%divisibility!=0:
+ height = height - (height%divisibility)
+
+ image = resizecrop(image, height, width)
+ else:
+ image_width, image_height = image.size
+ if image_height > image_width:
+ height, width = width, height
+ image = resizecrop(image, height, width)
+ except Exception as e:
+ raise ValueError(f"Failed to load or process image: {e}")
+
+ #negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+ negative_prompt = args.negative_prompt # 20250422 pftq: allow editable negative prompt
+
save_dir = os.path.join("result", args.outdir)
os.makedirs(save_dir, exist_ok=True)
local_rank = 0
@@ -127,6 +204,11 @@
gc.collect()
torch.cuda.empty_cache()
+ print(f"Rank {local_rank}: {width}x{height} | Image Input: "+str(image!=None) + " | Video Input: "+str(video!=None))
+
+ # 20250423 pftq: fixed 20-min load times on multi-GPU caused by contention, reduced down to 12 min roughly the same as single GPU.
+ print("Initializing pipe at "+time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()))
+ starttime = time.time()
pipe = DiffusionForcingPipeline(
args.model_id,
dit_path=args.model_id,
@@ -135,52 +217,104 @@
use_usp=args.use_usp,
offload=args.offload,
)
+ totaltime = time.time()-starttime
+ print("Finished initializing pipe at "+time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())+" ("+str(int(totaltime))+" seconds)")
if args.causal_attention:
pipe.transformer.set_ar_attention(args.causal_block_size)
-
+
if args.teacache:
if args.ar_step > 0:
- num_steps = (
- args.inference_steps
- + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
- )
- print("num_steps:", num_steps)
+ num_steps = args.inference_steps + (((args.base_num_frames - 1)//4 + 1) // args.causal_block_size - 1) * args.ar_step
+ print('num_steps:', num_steps)
else:
num_steps = args.inference_steps
- pipe.transformer.initialize_teacache(
- enable_teacache=True,
- num_steps=num_steps,
- teacache_thresh=args.teacache_thresh,
- use_ret_steps=args.use_ret_steps,
- ckpt_dir=args.model_id,
- )
+ pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps,
+ teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
+ ckpt_dir=args.model_id)
+
+ #20250422 pftq: Set preferred linear algebra backend to avoid cuSOLVER issues
+ torch.backends.cuda.preferred_linalg_library("default") # or try "magma" if available
+
+ for idx in range(args.batch_size): # 20250422 pftq: implemented --batch_size
+ if local_rank == 0:
+ print(f"prompt:{prompt_input}")
+ print(f"guidance_scale:{guidance_scale}")
+ print(f"Generating video {idx+1} of {args.batch_size}")
+
+ #20250422 pftq: Synchronize seed across all ranks
+ if args.use_usp:
+ try:
+ #20250422 pftq: Synchronize ranks before seed broadcasting
+ dist.barrier()
+
+ #20250422 pftq: Always broadcast seed to ensure consistency
+ if local_rank == 0:
+ if args.seed == -1 or idx > 0:
+ args.seed = int(random.randrange(4294967294))
+ seed_tensor = torch.tensor(args.seed, dtype=torch.int64, device="cuda")
+ dist.broadcast(seed_tensor, src=0)
+ args.seed = seed_tensor.item()
+
+ #20250422 pftq: Synchronize ranks after seed broadcasting
+ dist.barrier()
+ except Exception as e:
+ print(f"[Rank {local_rank}] Seed broadcasting error: {e}")
+ dist.destroy_process_group()
+ raise
+
+ else:
+ #20250422 pftq: Single GPU seed initialization
+ if args.seed == -1 or idx > 0:
+ args.seed = int(random.randrange(4294967294))
+
+ #20250422 pftq: Set seeds for reproducibility
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
- print(f"prompt:{prompt_input}")
- print(f"guidance_scale:{guidance_scale}")
-
- with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
- video_frames = pipe(
- prompt=prompt_input,
- negative_prompt=negative_prompt,
- image=image,
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=args.inference_steps,
- shift=shift,
- guidance_scale=guidance_scale,
- generator=torch.Generator(device="cuda").manual_seed(args.seed),
- overlap_history=args.overlap_history,
- addnoise_condition=args.addnoise_condition,
- base_num_frames=args.base_num_frames,
- ar_step=args.ar_step,
- causal_block_size=args.causal_block_size,
- fps=fps,
- )[0]
-
- if local_rank == 0:
current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
- video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
- output_path = os.path.join(save_dir, video_out_file)
- imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
+ #video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
+
+ # 20250422 pftq: more useful filename
+ gpucount = ""
+ if args.use_usp and dist.get_world_size():
+ gpucount = "_"+str(dist.get_world_size())+"xGPU"
+ prompt_summary = ""
+ if type(args.prompt) is list:
+ prompt_summary = args.prompt[0][:10].replace('/','')
+ else:
+ prompt_summary = args.prompt[:10].replace('/','')
+ video_out_file = f"{current_time}_skyreels2df_{args.resolution}-{args.num_frames}f_cfg{args.guidance_scale}_steps{args.inference_steps}_seed{args.seed}{gpucount}_{prompt_summary}_{idx}.mp4"
+
+ with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
+ video_frames = pipe(
+ prompt=prompt_input,
+ negative_prompt=negative_prompt,
+ image=image,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_inference_steps=args.inference_steps,
+ shift=shift,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator(device="cuda").manual_seed(args.seed),
+ overlap_history=args.overlap_history,
+ addnoise_condition=args.addnoise_condition,
+ base_num_frames=args.base_num_frames,
+ ar_step=args.ar_step,
+ causal_block_size=args.causal_block_size,
+ fps=fps,
+
+ # 20250425 chaojie prompt travel & video input
+ video=video,
+ local_rank=local_rank,
+ save_dir=save_dir,
+ video_out_file=video_out_file,
+ )[0]
+
+ if local_rank == 0:
+ output_path = os.path.join(save_dir, video_out_file)
+ imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
diff --git a/skyreels_v2_infer/modules/__init__.py b/skyreels_v2_infer/modules/__init__.py
index 5bc6afe..6f1c9be 100644
--- a/skyreels_v2_infer/modules/__init__.py
+++ b/skyreels_v2_infer/modules/__init__.py
@@ -27,18 +27,21 @@ def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
return vae
-def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
+def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16, skip_weights=False) -> WanModel:
+ # 20250423 pftq: Added skip_weights parameter to initialize empty model
config_path = os.path.join(model_path, "config.json")
transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
- for file in os.listdir(model_path):
- if file.endswith(".safetensors"):
- file_path = os.path.join(model_path, file)
- state_dict = load_file(file_path)
- transformer.load_state_dict(state_dict, strict=False)
- del state_dict
- gc.collect()
- torch.cuda.empty_cache()
+ if not skip_weights:
+ # 20250423 pftq: Only load weights if skip_weights=False
+ for file in os.listdir(model_path):
+ if file.endswith(".safetensors"):
+ file_path = os.path.join(model_path, file)
+ state_dict = load_file(file_path)
+ transformer.load_state_dict(state_dict, strict=False)
+ del state_dict
+ gc.collect()
+ torch.cuda.empty_cache()
transformer.requires_grad_(False)
transformer.eval()
@@ -47,10 +50,15 @@ def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> W
return transformer
-def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
+def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16, skip_weights=False) -> T5EncoderModel:
+ # 20250423 pftq: Added skip_weights and weights_only=True
t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
- text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
+ text_encoder = T5EncoderModel(
+ checkpoint_path=t5_model if not skip_weights else None,
+ tokenizer_path=tokenizer_path,
+ weights_only=True
+ ).to(device).to(weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
gc.collect()
diff --git a/skyreels_v2_infer/modules/t5.py b/skyreels_v2_infer/modules/t5.py
index b882fe7..d49cd02 100644
--- a/skyreels_v2_infer/modules/t5.py
+++ b/skyreels_v2_infer/modules/t5.py
@@ -425,6 +425,7 @@ def __init__(
tokenizer_path=None,
text_len=512,
shard_fn=None,
+ weights_only=False, # 20250423 pftq: Added for torch.load
):
self.text_len = text_len
self.checkpoint_path = checkpoint_path
@@ -433,8 +434,12 @@ def __init__(
super().__init__()
# init model
model = umt5_xxl(encoder_only=True, return_tokenizer=False)
- logging.info(f"loading {checkpoint_path}")
- model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
+ # 20250423 pftq: Load weights only if checkpoint_path is provided
+ if checkpoint_path:
+ logging.info(f"loading {checkpoint_path}")
+ model.load_state_dict(
+ torch.load(checkpoint_path, map_location="cpu", weights_only=weights_only)
+ )
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
diff --git a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
index b27f1d8..6da8834 100644
--- a/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
+++ b/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
@@ -5,6 +5,7 @@
from typing import Tuple
from typing import Union
+import imageio # 20250425 chaojie prompt travel & video input
import numpy as np
import torch
from diffusers.image_processor import PipelineImageInput
@@ -51,14 +52,100 @@ def __init__(
device (str): Device to run on, defaults to 'cuda'
weight_dtype: Weight data type, defaults to torch.bfloat16
"""
+
+ # 20250423 pftq: Fixed 20-min multi-gpu load time by loading on Rank 0 first and broadcasting
+
+ import torch.distributed as dist # 20250423 pftq: Added for rank checking and broadcasting
+ self.device = device
+ self.offload = offload
load_device = "cpu" if offload else device
- self.transformer = get_transformer(dit_path, load_device, weight_dtype)
+
+ # 20250423 pftq: Check rank and distributed mode
+ if use_usp:
+ if not dist.is_initialized():
+ raise RuntimeError("Distributed environment must be initialized with dist.init_process_group before using use_usp=True")
+ local_rank = dist.get_rank()
+ else:
+ local_rank = 0
+
+ print(f"[Rank {local_rank}] Initializing pipeline components...")
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
- self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
- self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
+ # 20250423 pftq: Load normally on single gpu
+ if not use_usp:
+ print(f"[Rank {local_rank}] Loading transformer to {load_device}...")
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype, skip_weights=False)
+ print(f"[Rank {local_rank}] Loading text encoder to {load_device}...")
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype, skip_weights=False)
+ print(f"[Rank {local_rank}] Loading VAE...")
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
+
+ # 20250423 pftq: Broadcast transformer from rank 0
+ if use_usp:
+ broadcast_device = "cpu" # tested to be more stable to start with cpu broadcast even if you have an H100
+ if local_rank == 0:
+ print(f"[Rank {local_rank}] Loading transformer to {broadcast_device}...")
+ self.transformer = get_transformer(dit_path, broadcast_device, weight_dtype, skip_weights=False)
+ transformer_state_dict = self.transformer.state_dict()
+ else:
+ print(f"[Rank {local_rank}] Skipping transformer load...")
+ self.transformer = get_transformer(dit_path, broadcast_device, weight_dtype, skip_weights=True)
+ transformer_state_dict = None
+ dist.barrier() # Ensure rank 0 loads transformer and text encoder
+ transformer_list = [transformer_state_dict]
+ print(f"[Rank {local_rank}] Broadcasting weights for transformer...")
+ dist.broadcast_object_list(transformer_list, src=0)
+ # 20250423 pftq: Load broadcasted weights on all ranks. Skip redundant load_state_dict on rank 0
+ if local_rank != 0:
+ print(f"[Rank {local_rank}] Loading broadcasted transformer...")
+ transformer_state_dict = transformer_list[0]
+ self.transformer.load_state_dict(transformer_state_dict)
+ dist.barrier()
+ if offload:
+ print(f"[Rank {local_rank}] Moving transformer to cpu...")
+ self.transformer.cpu()
+ else:
+ print(f"[Rank {local_rank}] Moving transformer to {device}...")
+ self.transformer.to(device)
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ # 20250423 pftq: Broadcast text encoder weights from rank 0
+ if local_rank == 0:
+ print(f"[Rank {local_rank}] Loading text encoder to {broadcast_device}...")
+ self.text_encoder = get_text_encoder(model_path, broadcast_device, weight_dtype, skip_weights=False)
+ text_encoder_state_dict = self.text_encoder.state_dict()
+ else:
+ print(f"[Rank {local_rank}] Skipping text encoder load...")
+ self.text_encoder = get_text_encoder(model_path, broadcast_device, weight_dtype, skip_weights=True)
+ text_encoder_state_dict = None
+ dist.barrier() # Ensure rank 0 loads transformer and text encoder
+ print(f"[Rank {local_rank}] Broadcasting weights for text encoder...")
+ text_encoder_list = [text_encoder_state_dict]
+ dist.broadcast_object_list(text_encoder_list, src=0)
+ # 20250423 pftq: Load broadcasted weights on all ranks. Skip redundant load_state_dict on rank 0
+ if local_rank != 0:
+ print(f"[Rank {local_rank}] Loading broadcasted text encoder...")
+ text_encoder_state_dict = text_encoder_list[0]
+ self.text_encoder.load_state_dict(text_encoder_state_dict)
+ dist.barrier()
+ if offload:
+ print(f"[Rank {local_rank}] Moving text encoder to cpu...")
+ self.text_encoder.cpu()
+ else:
+ print(f"[Rank {local_rank}] Moving text encoder to {device}...")
+ self.text_encoder.to(device)
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ # 20250423 pftq: Stagger VAE loading across ranks
+ for rank in range(dist.get_world_size()):
+ if local_rank == rank:
+ print(f"[Rank {local_rank}] Loading VAE...")
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
+ dist.barrier()
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
- self.device = device
- self.offload = offload
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
@@ -67,8 +154,8 @@ def __init__(
for block in self.transformer.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
- self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
- self.sp_size = get_sequence_parallel_world_size()
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
+ self.sp_size = get_sequence_parallel_world_size()
self.scheduler = FlowUniPCMultistepScheduler()
@@ -95,6 +182,27 @@ def encode_image(
predix_video_latent_length = prefix_video[0].shape[1]
return prefix_video, predix_video_latent_length
+ # 20250425 chaojie prompt travel & video input
+ def encode_video(
+ self, video: List[PipelineImageInput], height: int, width: int, num_frames: int
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ # prefix_video
+ prefix_video = np.array(video).transpose(3, 0, 1, 2)
+ prefix_video = torch.tensor(prefix_video) # .to(image_embeds.dtype).unsqueeze(1)
+ if prefix_video.dtype == torch.uint8:
+ prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
+ prefix_video = prefix_video.to(self.device)
+ prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
+ print(prefix_video[0].shape)
+ causal_block_size = self.transformer.num_frame_per_block
+ if prefix_video[0].shape[1] % causal_block_size != 0:
+ truncate_len = prefix_video[0].shape[1] % causal_block_size
+ print("the length of prefix video is truncated for the casual block size alignment.")
+ prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
+ predix_video_latent_length = prefix_video[0].shape[1]
+ return prefix_video, predix_video_latent_length
+
def prepare_latents(
self,
shape: Tuple[int],
@@ -183,9 +291,11 @@ def generate_timestep_matrix(
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]],
+ #prompt: Union[str, List[str]],
+ prompt, # 20250425 chaojie prompt travel & video input
negative_prompt: Union[str, List[str]] = "",
image: PipelineImageInput = None,
+ video: List[PipelineImageInput] = None, # 20250425 chaojie prompt travel & video input
height: int = 480,
width: int = 832,
num_frames: int = 97,
@@ -199,6 +309,11 @@ def __call__(
ar_step: int = 5,
causal_block_size: int = None,
fps: int = 24,
+
+ # 20250425 chaojie prompt travel & video input
+ local_rank: int = 0,
+ save_dir: str = "",
+ video_out_file: str = "",
):
latent_height = height // 8
latent_width = width // 8
@@ -211,9 +326,22 @@ def __call__(
predix_video_latent_length = 0
if image:
prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
+ # 20250425 chaojie prompt travel & video input
+ elif video:
+ prefix_video, predix_video_latent_length = self.encode_video(video, height, width, num_frames)
self.text_encoder.to(self.device)
- prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
+ #prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
+ # 20250425 chaojie prompt travel & video input
+ prompt_embeds_list = []
+ if type(prompt) is list:
+ for prompt_iter in prompt:
+ prompt_embeds_list.append(self.text_encoder.encode(prompt_iter).to(self.transformer.dtype))
+ else:
+ prompt_embeds_list.append(self.text_encoder.encode(prompt).to(self.transformer.dtype))
+ prompt_embeds = prompt_embeds_list[0]
+ prompt_readable = ""
+
if self.do_classifier_free_guidance:
negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
if self.offload:
@@ -316,8 +444,20 @@ def __call__(
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
print(f"n_iter:{n_iter}")
output_video = None
- for i in range(n_iter):
- if output_video is not None: # i !=0
+ #for i in range(n_iter):
+ #if output_video is not None: # i !=0
+ # 20250425 chaojie prompt travel & video input
+ for i_n_iter in range(n_iter):
+ if type(prompt) is list:
+ if len(prompt) > i_n_iter:
+ prompt_embeds = prompt_embeds_list[i_n_iter]
+ if local_rank == 0:
+ partnum = i_n_iter + 1
+ if len(prompt) > i_n_iter:
+ prompt_readable = prompt[i_n_iter]
+ print(f"Generating part {partnum} of {n_iter}: "+prompt_readable) # 20250425 pftq
+ if output_video is not None: # i_n_iter !=0
+
prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
if prefix_video[0].shape[1] % causal_block_size != 0:
@@ -325,13 +465,15 @@ def __call__(
print("the length of prefix video is truncated for the casual block size alignment.")
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
predix_video_latent_length = prefix_video[0].shape[1]
- finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
+ #finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
+ finished_frame_num = i_n_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames # 20250425 chaojie prompt travel & video input
left_frame_num = latent_length - finished_frame_num
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
if ar_step > 0 and self.transformer.enable_teacache:
num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
self.transformer.num_steps = num_steps
- else: # i == 0
+ #else: # i == 0
+ else: # i_n_iter == 0 # 20250425 chaojie prompt travel & video input
base_num_frames_iter = base_num_frames
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents(
@@ -420,6 +562,21 @@ def __call__(
output_video = torch.cat(
[output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
) # c, f, h, w
+
+ # 20250425 chaojie prompt travel & video input
+ if local_rank == 0:
+ videonum = i_n_iter + 1
+ print(f"Saving partial video {videonum} of {n_iter}...") # 20250425 pftq
+ mid_output_video = output_video
+ mid_output_video = [(mid_output_video / 2 + 0.5).clamp(0, 1)]
+ mid_output_video = [video for video in mid_output_video]
+ mid_output_video = [video.permute(1, 2, 3, 0) * 255 for video in mid_output_video]
+ mid_output_video = [video.cpu().numpy().astype(np.uint8) for video in mid_output_video]
+
+ mid_video_out_file = video_out_file.replace(".mp4", f"_partial{i_n_iter}.mp4")
+ mid_output_path = os.path.join(save_dir, mid_video_out_file)
+ imageio.mimwrite(mid_output_path, mid_output_video[0], fps=fps, quality=8, output_params=["-loglevel", "error"])
+
output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
output_video = [video for video in output_video]
output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
diff --git a/skyreels_v2_infer/pipelines/image2video_pipeline.py b/skyreels_v2_infer/pipelines/image2video_pipeline.py
index a260cf0..7bf05d6 100644
--- a/skyreels_v2_infer/pipelines/image2video_pipeline.py
+++ b/skyreels_v2_infer/pipelines/image2video_pipeline.py
@@ -39,15 +39,109 @@ class Image2VideoPipeline:
def __init__(
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
):
+ # 20250423 pftq: Fixed load time by broadcasting transformer and staggering text encoder, VAE, image encoder
+ import torch.distributed as dist
load_device = "cpu" if offload else device
- self.transformer = get_transformer(dit_path, load_device, weight_dtype)
- vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
- self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
- self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
- self.clip = get_image_encoder(model_path, load_device, weight_dtype)
- self.sp_size = 1
self.device = device
self.offload = offload
+
+ # 20250423 pftq: Check rank and distributed mode
+ if use_usp:
+ if not dist.is_initialized():
+ raise RuntimeError("Distributed environment must be initialized with dist.init_process_group before using use_usp=True")
+ local_rank = dist.get_rank()
+ else:
+ local_rank = 0
+
+ print(f"[Rank {local_rank}] Initializing pipeline components...")
+
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
+ # 20250423 pftq: Load normally on single gpu
+ if not use_usp:
+ print(f"[Rank {local_rank}] Loading transformer to {load_device}...")
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype, skip_weights=False)
+ print(f"[Rank {local_rank}] Loading text encoder to {load_device}...")
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype, skip_weights=False)
+ print(f"[Rank {local_rank}] Loading VAE...")
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
+
+ # 20250423 pftq: Broadcast transformer from rank 0
+ if use_usp:
+ broadcast_device = "cpu" # tested to be more stable to start with cpu broadcast even if you have an H100
+ if local_rank == 0:
+ print(f"[Rank {local_rank}] Loading transformer to {broadcast_device}...")
+ self.transformer = get_transformer(dit_path, broadcast_device, weight_dtype, skip_weights=False)
+ transformer_state_dict = self.transformer.state_dict()
+ else:
+ print(f"[Rank {local_rank}] Skipping transformer load...")
+ self.transformer = get_transformer(dit_path, broadcast_device, weight_dtype, skip_weights=True)
+ transformer_state_dict = None
+ dist.barrier() # Ensure rank 0 loads transformer and text encoder
+ transformer_list = [transformer_state_dict]
+ print(f"[Rank {local_rank}] Broadcasting weights for transformer...")
+ dist.broadcast_object_list(transformer_list, src=0)
+ # 20250423 pftq: Load broadcasted weights on all ranks. Skip redundant load_state_dict on rank 0
+ if local_rank != 0:
+ print(f"[Rank {local_rank}] Loading broadcasted transformer...")
+ transformer_state_dict = transformer_list[0]
+ self.transformer.load_state_dict(transformer_state_dict)
+ dist.barrier()
+ if offload:
+ print(f"[Rank {local_rank}] Moving transformer to cpu...")
+ self.transformer.cpu()
+ else:
+ print(f"[Rank {local_rank}] Moving transformer to {device}...")
+ self.transformer.to(device)
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ # 20250423 pftq: Broadcast text encoder weights from rank 0
+ if local_rank == 0:
+ print(f"[Rank {local_rank}] Loading text encoder to {broadcast_device}...")
+ self.text_encoder = get_text_encoder(model_path, broadcast_device, weight_dtype, skip_weights=False)
+ text_encoder_state_dict = self.text_encoder.state_dict()
+ else:
+ print(f"[Rank {local_rank}] Skipping text encoder load...")
+ self.text_encoder = get_text_encoder(model_path, broadcast_device, weight_dtype, skip_weights=True)
+ text_encoder_state_dict = None
+ dist.barrier() # Ensure rank 0 loads transformer and text encoder
+ print(f"[Rank {local_rank}] Broadcasting weights for text encoder...")
+ text_encoder_list = [text_encoder_state_dict]
+ dist.broadcast_object_list(text_encoder_list, src=0)
+ # 20250423 pftq: Load broadcasted weights on all ranks. Skip redundant load_state_dict on rank 0
+ if local_rank != 0:
+ print(f"[Rank {local_rank}] Loading broadcasted text encoder...")
+ text_encoder_state_dict = text_encoder_list[0]
+ self.text_encoder.load_state_dict(text_encoder_state_dict)
+ dist.barrier()
+ if offload:
+ print(f"[Rank {local_rank}] Moving text encoder to cpu...")
+ self.text_encoder.cpu()
+ else:
+ print(f"[Rank {local_rank}] Moving text encoder to {device}...")
+ self.text_encoder.to(device)
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ # 20250423 pftq: Stagger VAE loading across ranks
+ for rank in range(dist.get_world_size()):
+ if local_rank == rank:
+ print(f"[Rank {local_rank}] Loading VAE...")
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
+ dist.barrier()
+
+ # 20250423 pftq: Stagger image encoder loading across ranks
+ if use_usp:
+ for rank in range(dist.get_world_size()):
+ if local_rank == rank:
+ print(f"[Rank {local_rank}] Loading image encoder...")
+ self.clip = get_image_encoder(model_path, load_device, weight_dtype)
+ dist.barrier()
+ else:
+ print(f"[Rank {local_rank}] Loading image encoder...")
+ self.clip = get_image_encoder(model_path, load_device, weight_dtype)
+
+ self.sp_size = 1
self.video_processor = VideoProcessor(vae_scale_factor=16)
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
@@ -56,8 +150,9 @@ def __init__(
for block in self.transformer.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ # 20250423 pftq: Fixed indentation and removed duplicate forward assignment
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
- self.sp_size = get_sequence_parallel_world_size()
+ self.sp_size = get_sequence_parallel_world_size()
self.scheduler = FlowUniPCMultistepScheduler()
self.vae_stride = (4, 8, 8)
@@ -133,7 +228,7 @@ def __call__(
"y": y,
}
- self.transformer.to(self.device)
+ #self.transformer.to(self.device) # 20250425 pftq: loaded twice
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = torch.stack([latent]).to(self.device)
timestep = torch.stack([t]).to(self.device)
diff --git a/skyreels_v2_infer/pipelines/text2video_pipeline.py b/skyreels_v2_infer/pipelines/text2video_pipeline.py
index 05a1dd3..dbe3a80 100644
--- a/skyreels_v2_infer/pipelines/text2video_pipeline.py
+++ b/skyreels_v2_infer/pipelines/text2video_pipeline.py
@@ -18,15 +18,99 @@ class Text2VideoPipeline:
def __init__(
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
):
+ # 20250423 pftq: Fixed load time by broadcasting transformer and staggering text encoder, VAE
+ import torch.distributed as dist
load_device = "cpu" if offload else device
- self.transformer = get_transformer(dit_path, load_device, weight_dtype)
+ self.device = device
+ self.offload = offload
+
+ # 20250423 pftq: Check rank and distributed mode
+ if use_usp:
+ if not dist.is_initialized():
+ raise RuntimeError("Distributed environment must be initialized with dist.init_process_group before using use_usp=True")
+ local_rank = dist.get_rank()
+ else:
+ local_rank = 0
+
+ print(f"[Rank {local_rank}] Initializing pipeline components...")
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
- self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
- self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
+ # 20250423 pftq: Load normally on single gpu
+ if not use_usp:
+ print(f"[Rank {local_rank}] Loading transformer to {load_device}...")
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype, skip_weights=False)
+ print(f"[Rank {local_rank}] Loading text encoder to {load_device}...")
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype, skip_weights=False)
+ print(f"[Rank {local_rank}] Loading VAE...")
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
+
+ # 20250423 pftq: Broadcast transformer from rank 0
+ if use_usp:
+ broadcast_device = "cpu" # tested to be more stable to start with cpu broadcast even if you have an H100
+ if local_rank == 0:
+ print(f"[Rank {local_rank}] Loading transformer to {broadcast_device}...")
+ self.transformer = get_transformer(dit_path, broadcast_device, weight_dtype, skip_weights=False)
+ transformer_state_dict = self.transformer.state_dict()
+ else:
+ print(f"[Rank {local_rank}] Skipping transformer load...")
+ self.transformer = get_transformer(dit_path, broadcast_device, weight_dtype, skip_weights=True)
+ transformer_state_dict = None
+ dist.barrier() # Ensure rank 0 loads transformer and text encoder
+ transformer_list = [transformer_state_dict]
+ print(f"[Rank {local_rank}] Broadcasting weights for transformer...")
+ dist.broadcast_object_list(transformer_list, src=0)
+ # 20250423 pftq: Load broadcasted weights on all ranks. Skip redundant load_state_dict on rank 0
+ if local_rank != 0:
+ print(f"[Rank {local_rank}] Loading broadcasted transformer...")
+ transformer_state_dict = transformer_list[0]
+ self.transformer.load_state_dict(transformer_state_dict)
+ dist.barrier()
+ if offload:
+ print(f"[Rank {local_rank}] Moving transformer to cpu...")
+ self.transformer.cpu()
+ else:
+ print(f"[Rank {local_rank}] Moving transformer to {device}...")
+ self.transformer.to(device)
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ # 20250423 pftq: Broadcast text encoder weights from rank 0
+ if local_rank == 0:
+ print(f"[Rank {local_rank}] Loading text encoder to {broadcast_device}...")
+ self.text_encoder = get_text_encoder(model_path, broadcast_device, weight_dtype, skip_weights=False)
+ text_encoder_state_dict = self.text_encoder.state_dict()
+ else:
+ print(f"[Rank {local_rank}] Skipping text encoder load...")
+ self.text_encoder = get_text_encoder(model_path, broadcast_device, weight_dtype, skip_weights=True)
+ text_encoder_state_dict = None
+ dist.barrier() # Ensure rank 0 loads transformer and text encoder
+ print(f"[Rank {local_rank}] Broadcasting weights for text encoder...")
+ text_encoder_list = [text_encoder_state_dict]
+ dist.broadcast_object_list(text_encoder_list, src=0)
+ # 20250423 pftq: Load broadcasted weights on all ranks. Skip redundant load_state_dict on rank 0
+ if local_rank != 0:
+ print(f"[Rank {local_rank}] Loading broadcasted text encoder...")
+ text_encoder_state_dict = text_encoder_list[0]
+ self.text_encoder.load_state_dict(text_encoder_state_dict)
+ dist.barrier()
+ if offload:
+ print(f"[Rank {local_rank}] Moving text encoder to cpu...")
+ self.text_encoder.cpu()
+ else:
+ print(f"[Rank {local_rank}] Moving text encoder to {device}...")
+ self.text_encoder.to(device)
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ # 20250423 pftq: Stagger VAE loading across ranks
+ for rank in range(dist.get_world_size()):
+ if local_rank == rank:
+ print(f"[Rank {local_rank}] Loading VAE...")
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
+ dist.barrier()
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
self.sp_size = 1
- self.device = device
- self.offload = offload
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
@@ -34,8 +118,9 @@ def __init__(
for block in self.transformer.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ # 20250423 pftq: Fixed indentation and removed duplicate forward assignment
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
- self.sp_size = get_sequence_parallel_world_size()
+ self.sp_size = get_sequence_parallel_world_size()
self.scheduler = FlowUniPCMultistepScheduler()
self.vae_stride = (4, 8, 8)