Skip to content

Commit 66fdcc8

Browse files
[Training] Use inference pipeline for training validation (#585)
1 parent ed1e8d6 commit 66fdcc8

34 files changed

+544
-246
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
This directory contains end-to-end (e2e) example scripts for finetuning Wan2.1 I2V.
2+
3+
Execute the following commands from `FastVideo/` to run training:
4+
5+
- Download the crush-smol dataset:
6+
`bash examples/training/finetune/wan_i2v_14b_480p/crush_smol/download_dataset.sh`
7+
- Preprocess the videos and captions into latents:
8+
`bash examples/training/finetune/wan_i2v_14b_480p/crush_smol/preprocess_wan_data_i2v.sh`
9+
- Edit the following file and run finetuning:
10+
`bash examples/training/finetune/wan_i2v_14b_480p/crush_smol/finetune_i2v.sh`
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
3+
python scripts/huggingface/download_hf.py --repo_id "wlsaidhi/crush-smol-merged" --local_dir "data/crush-smol" --repo_type "dataset"
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/bin/bash
2+
3+
export WANDB_BASE_URL="https://api.wandb.ai"
4+
export WANDB_MODE=online
5+
export TOKENIZERS_PARALLELISM=false
6+
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
7+
8+
MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
9+
DATA_DIR="data/crush-smol_processed_i2v_1_3b_inp/combined_parquet_dataset/"
10+
VALIDATION_DATASET_FILE="examples/training/finetune/Wan2.1-Fun-1.3B-InP/crush_smol/validation.json"
11+
NUM_GPUS=8
12+
# export CUDA_VISIBLE_DEVICES=4,5
13+
# IP=[MASTER NODE IP]
14+
15+
# Training arguments
16+
training_args=(
17+
--tracker_project_name "wan_i2v_finetune"
18+
--output_dir "$DATA_DIR/outputs/wan_i2v_finetune"
19+
--max_train_steps 2000
20+
--train_batch_size 4
21+
--train_sp_batch_size 1
22+
--gradient_accumulation_steps 1
23+
--num_latent_t 8
24+
--num_height 480
25+
--num_width 832
26+
--num_frames 77
27+
)
28+
29+
# Parallel arguments
30+
parallel_args=(
31+
--num_gpus $NUM_GPUS
32+
--sp_size 4
33+
--tp_size 4
34+
--hsdp_replicate_dim 2
35+
--hsdp_shard_dim 4
36+
)
37+
38+
# Model arguments
39+
model_args=(
40+
--model_path $MODEL_PATH
41+
--pretrained_model_name_or_path $MODEL_PATH
42+
)
43+
44+
# Dataset arguments
45+
dataset_args=(
46+
--data_path "$DATA_DIR"
47+
--dataloader_num_workers 1
48+
)
49+
50+
# Validation arguments
51+
validation_args=(
52+
--log_validation
53+
--validation_dataset_file "$VALIDATION_DATASET_FILE"
54+
--validation_steps 100
55+
--validation_sampling_steps "40"
56+
--validation_guidance_scale "1.0"
57+
)
58+
59+
# Optimizer arguments
60+
optimizer_args=(
61+
--learning_rate 2e-5
62+
--mixed_precision "bf16"
63+
--checkpointing_steps 2000
64+
--weight_decay 1e-4
65+
--max_grad_norm 1.0
66+
)
67+
68+
# Miscellaneous arguments
69+
miscellaneous_args=(
70+
--inference_mode False
71+
--allow_tf32
72+
--checkpoints_total_limit 3
73+
--training_cfg_rate 0.1
74+
--multi_phased_distill_schedule "4000-1"
75+
--not_apply_cfg_solver
76+
--dit_precision "fp32"
77+
--num_euler_timesteps 50
78+
--ema_start_step 0
79+
--enable_gradient_checkpointing_type "full"
80+
)
81+
82+
# If you do not have 32 GPUs, you can make training fit in memory by: 1. increasing sp_size; 2. reducing num_latent_t
83+
torchrun \
84+
--nnodes 1 \
85+
--nproc_per_node $NUM_GPUS \
86+
fastvideo/v1/training/wan_i2v_training_pipeline.py \
87+
"${parallel_args[@]}" \
88+
"${model_args[@]}" \
89+
"${dataset_args[@]}" \
90+
"${training_args[@]}" \
91+
"${optimizer_args[@]}" \
92+
"${validation_args[@]}" \
93+
"${miscellaneous_args[@]}"
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=i2v
3+
#SBATCH --partition=main
4+
#SBATCH --qos=hao
5+
#SBATCH --nodes=4
6+
#SBATCH --ntasks=4
7+
#SBATCH --ntasks-per-node=1
8+
#SBATCH --gres=gpu:8
9+
#SBATCH --cpus-per-task=128
10+
#SBATCH --nodelist=fs-mbz-gpu-[100-850]
11+
#SBATCH --mem=1440G
12+
#SBATCH --output=i2v_output/i2v_%j.out
13+
#SBATCH --error=i2v_output/i2v_%j.err
14+
#SBATCH --exclusive
15+
set -e -x
16+
17+
# Environment Setup
18+
source ~/conda/miniconda/bin/activate
19+
conda activate will-fv
20+
21+
# Basic Info
22+
export WANDB_MODE="online"
23+
export NCCL_P2P_DISABLE=1
24+
export TORCH_NCCL_ENABLE_MONITORING=0
25+
# different cache dir for different processes
26+
export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID}
27+
export MASTER_PORT=29500
28+
export NODE_RANK=$SLURM_PROCID
29+
nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )
30+
export MASTER_ADDR=${nodes[0]}
31+
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
32+
export TOKENIZERS_PARALLELISM=false
33+
export WANDB_BASE_URL="https://api.wandb.ai"
34+
export WANDB_MODE=online
35+
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
36+
37+
echo "MASTER_ADDR: $MASTER_ADDR"
38+
echo "NODE_RANK: $NODE_RANK"
39+
40+
# Configs
41+
NUM_GPUS=8
42+
43+
MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
44+
DATA_DIR="data/crush-smol_processed_i2v_1_3b_inp/combined_parquet_dataset/"
45+
VALIDATION_DATASET_FILE="examples/training/finetune/Wan2.1-Fun-1.3B-InP/crush_smol/validation.json"
46+
# export CUDA_VISIBLE_DEVICES=4,5
47+
# IP=[MASTER NODE IP]
48+
49+
# If you do not have 32 GPUs, you can make training fit in memory by: 1. increasing sp_size; 2. reducing num_latent_t
50+
51+
# Training arguments
52+
training_args=(
53+
--tracker_project_name wan_i2v_finetune
54+
--output_dir="$DATA_DIR/outputs/wan_i2v_finetune_2n"
55+
--max_train_steps=2000
56+
--train_batch_size=2
57+
--train_sp_batch_size 1
58+
--gradient_accumulation_steps=1
59+
--num_latent_t 8
60+
--num_height 480
61+
--num_width 832
62+
--num_frames 77
63+
)
64+
65+
# Parallel arguments
66+
parallel_args=(
67+
--num_gpus $NUM_GPUS
68+
--sp_size $NUM_GPUS
69+
--tp_size $NUM_GPUS
70+
--hsdp_replicate_dim $SLURM_JOB_NUM_NODES
71+
--hsdp_shard_dim $NUM_GPUS
72+
)
73+
74+
# Model arguments
75+
model_args=(
76+
--model_path $MODEL_PATH
77+
--pretrained_model_name_or_path $MODEL_PATH
78+
)
79+
80+
# Dataset arguments
81+
dataset_args=(
82+
--data_path "$DATA_DIR"
83+
--dataloader_num_workers 10
84+
)
85+
86+
# Validation arguments
87+
validation_args=(
88+
--log_validation
89+
--validation_dataset_file "$VALIDATION_DATASET_FILE"
90+
--validation_steps 100
91+
--validation_sampling_steps "40"
92+
--validation_guidance_scale "1.0"
93+
)
94+
95+
# Optimizer arguments
96+
optimizer_args=(
97+
--learning_rate=1e-5
98+
--mixed_precision="bf16"
99+
--checkpointing_steps=1000
100+
--weight_decay 1e-4
101+
--max_grad_norm 1.0
102+
)
103+
104+
# Miscellaneous arguments
105+
miscellaneous_args=(
106+
--inference_mode False
107+
--allow_tf32
108+
--checkpoints_total_limit 3
109+
--training_cfg_rate 0.1
110+
--multi_phased_distill_schedule "4000-1"
111+
--not_apply_cfg_solver
112+
--dit_precision "fp32"
113+
--num_euler_timesteps 50
114+
--ema_start_step 0
115+
--enable_gradient_checkpointing_type "full"
116+
)
117+
118+
srun torchrun \
119+
--nnodes $SLURM_JOB_NUM_NODES \
120+
--nproc_per_node $NUM_GPUS \
121+
--node_rank $SLURM_PROCID \
122+
--rdzv_backend=c10d \
123+
--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
124+
fastvideo/v1/training/wan_i2v_training_pipeline.py \
125+
"${parallel_args[@]}" \
126+
"${model_args[@]}" \
127+
"${dataset_args[@]}" \
128+
"${training_args[@]}" \
129+
"${optimizer_args[@]}" \
130+
"${validation_args[@]}" \
131+
"${miscellaneous_args[@]}"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/bin/bash
2+
3+
GPU_NUM=1 # 2,4,8
4+
MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
5+
MODEL_TYPE="wan"
6+
DATA_MERGE_PATH="data/crush-smol/merge.txt"
7+
OUTPUT_DIR="data/crush-smol_processed_i2v_1_3b_inp/"
8+
VALIDATION_PATH="examples/training/finetune/Wan2.1-Fun-1.3B-InP/crush_smol/validation.json"
9+
10+
torchrun --nproc_per_node=$GPU_NUM \
11+
fastvideo/v1/pipelines/preprocess/v1_preprocess.py \
12+
--model_path $MODEL_PATH \
13+
--data_merge_path $DATA_MERGE_PATH \
14+
--preprocess_video_batch_size 8 \
15+
--max_height 480 \
16+
--max_width 832 \
17+
--num_frames 77 \
18+
--dataloader_num_workers 0 \
19+
--output_dir=$OUTPUT_DIR \
20+
--model_type $MODEL_TYPE \
21+
--train_fps 16 \
22+
--validation_dataset_file $VALIDATION_PATH \
23+
--samples_per_file 8 \
24+
--flush_frequency 8 \
25+
--video_length_tolerance_range 5 \
26+
--preprocess_task "i2v"
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
{
2+
"data": [
3+
{
4+
"caption": "A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.",
5+
"image_path": null,
6+
"video_path": "validation_dataset/yYcK4nANZz4-Scene-034.mp4",
7+
"num_inference_steps": 40,
8+
"height": 480,
9+
"width": 832,
10+
"num_frames": 77
11+
},
12+
{
13+
"caption": "A large metal cylinder is seen compressing colorful clay into a compact shape, demonstrating the power of a hydraulic press.",
14+
"image_path": null,
15+
"video_path": "validation_dataset/yYcK4nANZz4-Scene-027.mp4",
16+
"num_inference_steps": 40,
17+
"height": 480,
18+
"width": 832,
19+
"num_frames": 77
20+
},
21+
{
22+
"caption": "A large metal cylinder is seen pressing down on a pile of colorful candies, flattening them as if they were under a hydraulic press. The candies are crushed and broken into small pieces, creating a mess on the table.",
23+
"image_path": null,
24+
"video_path": "validation_dataset/yYcK4nANZz4-Scene-030.mp4",
25+
"num_inference_steps": 40,
26+
"height": 480,
27+
"width": 832,
28+
"num_frames": 77
29+
}
30+
]
31+
}

examples/training/finetune/wan_i2v_14b_480p/crush_smol/finetune_i2v.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
export WANDB_BASE_URL="https://api.wandb.ai"
44
export WANDB_MODE=online
5+
export TOKENIZERS_PARALLELISM=false
56
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
67

78
MODEL_PATH="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
89
DATA_DIR="data/crush-smol_processed_i2v/combined_parquet_dataset/"
9-
VALIDATION_DIR="data/crush-smol_processed_i2v/validation_parquet_dataset/"
10+
VALIDATION_DATASET_FILE="examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation.json"
1011
NUM_GPUS=8
1112
# export CUDA_VISIBLE_DEVICES=4,5
1213
# IP=[MASTER NODE IP]
@@ -49,7 +50,7 @@ dataset_args=(
4950
# Validation arguments
5051
validation_args=(
5152
--log_validation
52-
--validation_preprocessed_path "$VALIDATION_DIR"
53+
--validation_dataset_file "$VALIDATION_DATASET_FILE"
5354
--validation_steps 100
5455
--validation_sampling_steps "40"
5556
--validation_guidance_scale "1.0"
@@ -75,6 +76,7 @@ miscellaneous_args=(
7576
--dit_precision "fp32"
7677
--num_euler_timesteps 50
7778
--ema_start_step 0
79+
--enable_gradient_checkpointing_type "full"
7880
)
7981

8082
# If you do not have 32 GPUs, you can make training fit in memory by: 1. increasing sp_size; 2. reducing num_latent_t

0 commit comments

Comments
 (0)