Skip to content

Commit e3d0cbe

Browse files
[STA] Implement mask search for V1's Wan2.1 (#415)
Co-authored-by: BrianChen1129 <[email protected]>
1 parent d5ec468 commit e3d0cbe

File tree

9 files changed

+803
-34
lines changed

9 files changed

+803
-34
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# STA Mask Search Examples
2+
3+
```bash
4+
bash examples/inference/sta_mask_search/inference_wan_sta.sh
5+
```
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/bin/bash
2+
3+
export FASTVIDEO_ATTENTION_CONFIG=assets/mask_strategy_wan.json
4+
export FASTVIDEO_ATTENTION_BACKEND=SLIDING_TILE_ATTN
5+
export MODEL_BASE=Wan-AI/Wan2.1-T2V-14B-Diffusers
6+
7+
base_port=29503
8+
num_gpu=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
9+
gpu_ids=$(seq 0 $((num_gpu-1)))
10+
skip_time_steps=12
11+
12+
output_path="inference_results/sta/mask_search_full"
13+
STA_mode="STA_searching"
14+
for i in $gpu_ids; do
15+
port=$((base_port+i))
16+
CUDA_VISIBLE_DEVICES=$i MASTER_PORT=$port python examples/inference/sta_mask_search/wan_example.py \
17+
--prompt_path ./assets/prompt_extend_${i}.txt \
18+
--output_path $output_path \
19+
--STA_mode $STA_mode &
20+
sleep 1
21+
done
22+
wait
23+
echo "STA searching completed"
24+
25+
output_path="inference_results/sta/mask_search_sparse"
26+
STA_mode="STA_tuning"
27+
for i in $gpu_ids; do
28+
port=$((base_port+i))
29+
CUDA_VISIBLE_DEVICES=$i MASTER_PORT=$port python examples/inference/sta_mask_search/wan_example.py \
30+
--prompt_path ./assets/prompt_extend_${i}.txt \
31+
--output_path $output_path \
32+
--STA_mode $STA_mode \
33+
--skip_time_steps $skip_time_steps &
34+
sleep 1
35+
done
36+
wait
37+
echo "STA tuning completed"
38+
39+
echo "All jobs completed"
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import argparse
3+
from fastvideo import VideoGenerator, SamplingParam
4+
5+
def main(args):
6+
os.makedirs(args.output_path, exist_ok=True)
7+
# Create a video generator with a pre-trained model
8+
generator = VideoGenerator.from_pretrained(
9+
"Wan-AI/Wan2.1-T2V-14B-Diffusers",
10+
num_gpus=args.num_gpus, # Adjust based on your hardware
11+
STA_mode=args.STA_mode,
12+
skip_time_steps=args.skip_time_steps
13+
)
14+
15+
# Prompts for your video
16+
prompt = args.prompt
17+
prompt_path = args.prompt_path
18+
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
19+
20+
if prompt_path is not None:
21+
with open(prompt_path, "r") as f:
22+
prompts = f.readlines()
23+
else:
24+
prompts = [prompt]
25+
26+
params = SamplingParam(
27+
height=args.height,
28+
width=args.width,
29+
num_frames=args.num_frames,
30+
num_inference_steps=args.num_inference_steps,
31+
fps=args.fps,
32+
guidance_scale=args.guidance_scale,
33+
seed=args.seed,
34+
return_frames=True, # Also return frames from this call (defaults to False)
35+
output_path=args.output_path, # Controls where videos are saved
36+
save_video=True,
37+
negative_prompt=negative_prompt
38+
)
39+
40+
# Generate the video
41+
for prompt in prompts:
42+
video = generator.generate_video(
43+
prompt,
44+
sampling_param=params,
45+
)
46+
47+
if __name__ == '__main__':
48+
parser = argparse.ArgumentParser()
49+
parser.add_argument("--prompt", type=str, default="A man is dancing.")
50+
parser.add_argument("--prompt_path", type=str, default=None)
51+
parser.add_argument("--height", type=int, default=768)
52+
parser.add_argument("--width", type=int, default=1280)
53+
parser.add_argument("--num_frames", type=int, default=69)
54+
parser.add_argument("--num_inference_steps", type=int, default=50)
55+
parser.add_argument("--fps", type=int, default=16)
56+
parser.add_argument("--guidance_scale", type=float, default=5.0)
57+
parser.add_argument("--seed", type=int, default=12345)
58+
parser.add_argument("--output_path", type=str, default="my_videos/")
59+
parser.add_argument("--num_gpus", type=int, default=1)
60+
parser.add_argument("--STA_mode", type=str, default="STA_searching")
61+
parser.add_argument("--skip_time_steps", type=int, default=12)
62+
args = parser.parse_args()
63+
main(args)

0 commit comments

Comments
 (0)