Skip to content

Commit 46a0a85

Browse files
[Training] Fixes SP for training; Improve Datasets and schema (#555)
1 parent 4aeabbc commit 46a0a85

26 files changed

+930
-211
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
3+
python scripts/huggingface/download_hf.py --repo_id "wlsaidhi/crush-smol-merged" --local_dir "data/crush-smol" --repo_type "dataset"
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/bin/bash
2+
3+
export WANDB_BASE_URL="https://api.wandb.ai"
4+
export WANDB_MODE=online
5+
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
6+
7+
MODEL_PATH="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
8+
DATA_DIR="data/crush-smol_processed_i2v/combined_parquet_dataset/"
9+
VALIDATION_DIR="data/crush-smol_processed_i2v/validation_parquet_dataset/"
10+
NUM_GPUS=8
11+
# export CUDA_VISIBLE_DEVICES=4,5
12+
# IP=[MASTER NODE IP]
13+
14+
# Training arguments
15+
training_args=(
16+
--tracker_project_name "wan_i2v_finetune"
17+
--output_dir "$DATA_DIR/outputs/wan_i2v_finetune"
18+
--max_train_steps 2000
19+
--train_batch_size 1
20+
--train_sp_batch_size 1
21+
--gradient_accumulation_steps 1
22+
--num_latent_t 8
23+
--num_height 480
24+
--num_width 832
25+
--num_frames 77
26+
)
27+
28+
# Parallel arguments
29+
parallel_args=(
30+
--num_gpus "$NUM_GPUS"  # total number of processes launched by torchrun below
31+
--sp_size "$NUM_GPUS"  # derived from NUM_GPUS (was a literal 8) so it tracks the GPU count
32+
--tp_size "$NUM_GPUS"  # derived from NUM_GPUS (was a literal 8), matching the other training scripts
33+
--hsdp_replicate_dim 1  # single node: one replica group
34+
--hsdp_shard_dim "$NUM_GPUS"  # derived from NUM_GPUS (was a literal 8)
35+
)
36+
37+
# Model arguments
38+
model_args=(
39+
--model_path $MODEL_PATH
40+
--pretrained_model_name_or_path $MODEL_PATH
41+
)
42+
43+
# Dataset arguments
44+
dataset_args=(
45+
--data_path "$DATA_DIR"
46+
--dataloader_num_workers 1
47+
)
48+
49+
# Validation arguments
50+
validation_args=(
51+
--log_validation
52+
--validation_preprocessed_path "$VALIDATION_DIR"
53+
--validation_steps 100
54+
--validation_sampling_steps "40"
55+
--validation_guidance_scale "1.0"
56+
)
57+
58+
# Optimizer arguments
59+
optimizer_args=(
60+
--learning_rate 1e-5
61+
--mixed_precision "bf16"
62+
--checkpointing_steps 1000
63+
--weight_decay 1e-4
64+
--max_grad_norm 1.0
65+
)
66+
67+
# Miscellaneous arguments
68+
miscellaneous_args=(
69+
--inference_mode False
70+
--allow_tf32
71+
--checkpoints_total_limit 3
72+
--cfg 0.0
73+
--multi_phased_distill_schedule "4000-1"
74+
--not_apply_cfg_solver
75+
--dit_precision "fp32"
76+
--num_euler_timesteps 50
77+
--ema_start_step 0
78+
)
79+
80+
# If you have fewer than 32 GPUs, you can fit training in memory by: 1. increasing sp_size; 2. reducing num_latent_t.
81+
torchrun \
82+
--nnodes 1 \
83+
--nproc_per_node $NUM_GPUS \
84+
fastvideo/v1/training/wan_i2v_training_pipeline.py \
85+
"${parallel_args[@]}" \
86+
"${model_args[@]}" \
87+
"${dataset_args[@]}" \
88+
"${training_args[@]}" \
89+
"${optimizer_args[@]}" \
90+
"${validation_args[@]}" \
91+
"${miscellaneous_args[@]}"
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=i2v
3+
#SBATCH --partition=main
4+
#SBATCH --qos=hao
5+
#SBATCH --nodes=4
6+
#SBATCH --ntasks=4
7+
#SBATCH --ntasks-per-node=1
8+
#SBATCH --gres=gpu:8
9+
#SBATCH --cpus-per-task=128
10+
#SBATCH --nodelist=fs-mbz-gpu-[100-850]
11+
#SBATCH --mem=1440G
12+
#SBATCH --output=i2v_output/i2v_%j.out
13+
#SBATCH --error=i2v_output/i2v_%j.err
14+
#SBATCH --exclusive
15+
set -e -x
16+
17+
# Environment Setup
18+
source ~/conda/miniconda/bin/activate
19+
conda activate will-fv
20+
21+
# Basic Info
22+
export WANDB_MODE="online"
23+
export NCCL_P2P_DISABLE=1
24+
export TORCH_NCCL_ENABLE_MONITORING=0
25+
# Use a separate Triton cache directory per Slurm process ID.
26+
export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID}  # NOTE(review): expanded here, before srun starts the tasks — presumably the same value everywhere; confirm
27+
export MASTER_PORT=29500  # port used in the --rdzv_endpoint passed to torchrun below
28+
export NODE_RANK=$SLURM_PROCID  # NOTE(review): also expanded before srun, so presumably identical on every node — confirm
29+
nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )  # split the output of `scontrol show hostnames` into an array
30+
export MASTER_ADDR=${nodes[0]}  # first listed host is used as the rendezvous address
31+
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID  # NOTE(review): a single-device mask looks too narrow for --nproc_per_node $NUM_GPUS (8) unless Slurm overrides it for the srun step — confirm
32+
export TOKENIZERS_PARALLELISM=false
33+
export WANDB_BASE_URL="https://api.wandb.ai"
34+
export WANDB_MODE=online
35+
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
36+
37+
echo "MASTER_ADDR: $MASTER_ADDR"
38+
echo "NODE_RANK: $NODE_RANK"
39+
40+
# Configs
41+
NUM_GPUS=8
42+
43+
MODEL_PATH="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
44+
DATA_DIR="data/crush-smol_processed_i2v/combined_parquet_dataset/"
45+
VALIDATION_DIR="data/crush-smol_processed_i2v/validation_parquet_dataset/"
46+
# export CUDA_VISIBLE_DEVICES=4,5
47+
# IP=[MASTER NODE IP]
48+
49+
# If you have fewer than 32 GPUs, you can fit training in memory by: 1. increasing sp_size; 2. reducing num_latent_t.
50+
51+
# Training arguments
52+
training_args=(
53+
--tracker_project_name wan_i2v_finetune
54+
--output_dir="$DATA_DIR/outputs/wan_i2v_finetune_2n"
55+
--max_train_steps=2000
56+
--train_batch_size=2
57+
--train_sp_batch_size 1
58+
--gradient_accumulation_steps=1
59+
--num_latent_t 8
60+
--num_height 480
61+
--num_width 832
62+
--num_frames 77
63+
)
64+
65+
# Parallel arguments
66+
parallel_args=(
67+
--num_gpus $NUM_GPUS
68+
--sp_size $NUM_GPUS
69+
--tp_size $NUM_GPUS
70+
--hsdp_replicate_dim $SLURM_JOB_NUM_NODES
71+
--hsdp_shard_dim $NUM_GPUS
72+
)
73+
74+
# Model arguments
75+
model_args=(
76+
--model_path $MODEL_PATH
77+
--pretrained_model_name_or_path $MODEL_PATH
78+
)
79+
80+
# Dataset arguments
81+
dataset_args=(
82+
--data_path "$DATA_DIR"
83+
--dataloader_num_workers 10
84+
)
85+
86+
# Validation arguments
87+
validation_args=(
88+
--log_validation
89+
--validation_preprocessed_path "$VALIDATION_DIR"
90+
--validation_steps 100
91+
--validation_sampling_steps "40"
92+
--validation_guidance_scale "1.0"
93+
)
94+
95+
# Optimizer arguments
96+
optimizer_args=(
97+
--learning_rate=1e-5
98+
--mixed_precision="bf16"
99+
--checkpointing_steps=1000
100+
--weight_decay 1e-4
101+
--max_grad_norm 1.0
102+
)
103+
104+
# Miscellaneous arguments
105+
miscellaneous_args=(
106+
--inference_mode False
107+
--allow_tf32
108+
--checkpoints_total_limit 3
109+
--cfg 0.0
110+
--multi_phased_distill_schedule "4000-1"
111+
--not_apply_cfg_solver
112+
--dit_precision "fp32"
113+
--num_euler_timesteps 50
114+
--ema_start_step 0
115+
)
116+
117+
srun torchrun \
118+
--nnodes $SLURM_JOB_NUM_NODES \
119+
--nproc_per_node $NUM_GPUS \
120+
--node_rank $SLURM_PROCID \
121+
--rdzv_backend=c10d \
122+
--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
123+
fastvideo/v1/training/wan_i2v_training_pipeline.py \
124+
"${parallel_args[@]}" \
125+
"${model_args[@]}" \
126+
"${dataset_args[@]}" \
127+
"${training_args[@]}" \
128+
"${optimizer_args[@]}" \
129+
"${validation_args[@]}" \
130+
"${miscellaneous_args[@]}"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/bin/bash
2+
3+
GPU_NUM=1 # 2,4,8
4+
MODEL_PATH="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
5+
MODEL_TYPE="wan"
6+
DATA_MERGE_PATH="data/crush-smol/merge.txt"
7+
OUTPUT_DIR="data/crush-smol_processed_i2v/"
8+
VALIDATION_PATH="examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation.json"
9+
10+
torchrun --nproc_per_node=$GPU_NUM \
11+
fastvideo/v1/pipelines/preprocess/v1_preprocess.py \
12+
--model_path $MODEL_PATH \
13+
--data_merge_path $DATA_MERGE_PATH \
14+
--preprocess_video_batch_size 8 \
15+
--max_height 480 \
16+
--max_width 832 \
17+
--num_frames 77 \
18+
--dataloader_num_workers 0 \
19+
--output_dir=$OUTPUT_DIR \
20+
--model_type $MODEL_TYPE \
21+
--train_fps 16 \
22+
--validation_dataset_file $VALIDATION_PATH \
23+
--samples_per_file 8 \
24+
--flush_frequency 8 \
25+
--preprocess_task "i2v"

examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,26 @@
33
{
44
"caption": "A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.",
55
"image_path": null,
6-
"video_path": "examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation_dataset/yYcK4nANZz4-Scene-034.mp4",
7-
"num_inference_steps": 50,
6+
"video_path": "validation_dataset/yYcK4nANZz4-Scene-034.mp4",
7+
"num_inference_steps": 40,
88
"height": 480,
99
"width": 832,
1010
"num_frames": 77
1111
},
1212
{
1313
"caption": "A large metal cylinder is seen compressing colorful clay into a compact shape, demonstrating the power of a hydraulic press.",
1414
"image_path": null,
15-
"video_path": "examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation_dataset/yYcK4nANZz4-Scene-027.mp4",
16-
"num_inference_steps": 50,
15+
"video_path": "validation_dataset/yYcK4nANZz4-Scene-027.mp4",
16+
"num_inference_steps": 40,
1717
"height": 480,
1818
"width": 832,
1919
"num_frames": 77
2020
},
2121
{
2222
"caption": "A large metal cylinder is seen pressing down on a pile of colorful candies, flattening them as if they were under a hydraulic press. The candies are crushed and broken into small pieces, creating a mess on the table.",
2323
"image_path": null,
24-
"video_path": "examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation_dataset/yYcK4nANZz4-Scene-030.mp4",
25-
"num_inference_steps": 50,
24+
"video_path": "validation_dataset/yYcK4nANZz4-Scene-030.mp4",
25+
"num_inference_steps": 40,
2626
"height": 480,
2727
"width": 832,
2828
"num_frames": 77
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
3+
python scripts/huggingface/download_hf.py --repo_id "wlsaidhi/crush-smol-merged" --local_dir "data/crush-smol" --repo_type "dataset"
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/bin/bash
2+
3+
export WANDB_BASE_URL="https://api.wandb.ai"
4+
export WANDB_MODE=online
5+
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA
6+
7+
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
8+
DATA_DIR="data/crush-smol_processed_t2v/combined_parquet_dataset/"
9+
VALIDATION_DIR="data/crush-smol_processed_t2v/validation_parquet_dataset/"
10+
NUM_GPUS=4
11+
# export CUDA_VISIBLE_DEVICES=4,5
12+
13+
14+
# Training arguments
15+
training_args=(
16+
--tracker_project_name "wan_t2v_finetune"
17+
--output_dir "outputs/wan_t2v_finetune"
18+
--max_train_steps 5000
19+
--train_batch_size 1
20+
--train_sp_batch_size 1
21+
--gradient_accumulation_steps 8
22+
--num_latent_t 8
23+
--num_height 480
24+
--num_width 832
25+
--num_frames 77
26+
)
27+
28+
# Parallel arguments
29+
parallel_args=(
30+
--num_gpus $NUM_GPUS
31+
--sp_size $NUM_GPUS
32+
--tp_size $NUM_GPUS
33+
--hsdp_replicate_dim 1
34+
--hsdp_shard_dim $NUM_GPUS
35+
)
36+
37+
# Model arguments
38+
model_args=(
39+
--model_path $MODEL_PATH
40+
--pretrained_model_name_or_path $MODEL_PATH
41+
)
42+
43+
# Dataset arguments
44+
dataset_args=(
45+
--data_path "$DATA_DIR"  # quoted so an empty or space-containing path stays one argument
46+
--dataloader_num_workers 1
47+
)
48+
49+
# Validation arguments
50+
validation_args=(
51+
--log_validation
52+
--validation_preprocessed_path "$VALIDATION_DIR"  # quoted so an empty or space-containing path stays one argument
53+
--validation_steps 50
54+
--validation_sampling_steps "50"
55+
--validation_guidance_scale "1.0"
56+
)
57+
58+
# Optimizer arguments
59+
optimizer_args=(
60+
--learning_rate 5e-5
61+
--mixed_precision "bf16"
62+
--checkpointing_steps 6000  # NOTE(review): larger than --max_train_steps 5000 in training_args, so presumably no periodic checkpoint is reached — confirm this is intended
63+
--weight_decay 1e-4
64+
--max_grad_norm 1.0
65+
)
66+
67+
# Miscellaneous arguments
68+
miscellaneous_args=(
69+
--inference_mode False
70+
--allow_tf32
71+
--checkpoints_total_limit 3
72+
--cfg 0.0
73+
--multi_phased_distill_schedule "4000-1"
74+
--not_apply_cfg_solver
75+
--dit_precision "fp32"
76+
--num_euler_timesteps 50
77+
--ema_start_step 0
78+
)
79+
80+
torchrun \
81+
--nnodes 1 \
82+
--nproc_per_node $NUM_GPUS \
83+
fastvideo/v1/training/wan_training_pipeline.py \
84+
"${parallel_args[@]}" \
85+
"${model_args[@]}" \
86+
"${dataset_args[@]}" \
87+
"${training_args[@]}" \
88+
"${optimizer_args[@]}" \
89+
"${validation_args[@]}" \
90+
"${miscellaneous_args[@]}"

0 commit comments

Comments
 (0)