# Diffusion Model Fine-tuning with Automodel Backend

Train diffusion models with distributed training support using NeMo Automodel and flow matching.

**Currently Supported:** Wan 2.1 Text-to-Video (1.3B and 14B models)

---

## Quick Start

### 1. Docker Setup

```bash
# Build image
docker build -f docker/Dockerfile.ci -t dfm-training .

# Run container
docker run --gpus all -it \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  --ipc=host \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  dfm-training bash

# Inside container: Initialize submodules
export UV_PROJECT_ENVIRONMENT=
git submodule update --init --recursive 3rdparty/
```

### 2. Prepare Data
We provide two ways to prepare your dataset:

- Start with raw videos: Place your `.mp4` files in a folder and use our data-preparation scripts to scan the videos and generate a `meta.json` entry for each sample (which includes `width`, `height`, `start_frame`, `end_frame`, and a caption). If you have captions, you can also include a per-video caption file named `<video>.jsonl`; the scripts will pick up the text automatically. The final dataset layout is shown below.
- Bring your own `meta.json`: If you already have annotations, create `meta.json` yourself following the schema shown below.

**Create video dataset:**
In the following example we use two video files, solely for demonstration purposes. Actual training datasets will contain a much larger number of files.
```
<your_video_folder>/
├── video1.mp4
├── video2.mp4
└── meta.json
```

**meta.json format:**
```json
[
  {
    "file_name": "video1.mp4",
    "width": 1280,
    "height": 720,
    "start_frame": 0,
    "end_frame": 121,
    "vila_caption": "A detailed description of the video1.mp4 contents..."
  },
  {
    "file_name": "video2.mp4",
    "width": 1280,
    "height": 720,
    "start_frame": 0,
    "end_frame": 12,
    "vila_caption": "A detailed description of the video2.mp4 contents..."
  }
]
```

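If you are assembling `meta.json` yourself, a small script can do the scanning. The sketch below is illustrative only (it is not one of the repo's data-preparation scripts): it assumes OpenCV (`cv2`) is available, fills `end_frame` with the video's total frame count, and leaves `vila_caption` as a placeholder to fill in by hand.

```python
# Minimal sketch (not part of the repo): build meta.json for a folder of .mp4 files.
# Assumes OpenCV is installed; captions are left as placeholders to edit manually.
import json
from pathlib import Path

import cv2


def build_meta(video_folder: str) -> None:
    entries = []
    for video in sorted(Path(video_folder).glob("*.mp4")):
        cap = cv2.VideoCapture(str(video))
        if not cap.isOpened():
            continue  # skip unreadable files
        entries.append({
            "file_name": video.name,
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "start_frame": 0,
            "end_frame": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            "vila_caption": "TODO: add a caption for " + video.name,
        })
        cap.release()
    with open(Path(video_folder) / "meta.json", "w") as f:
        json.dump(entries, f, indent=2)


if __name__ == "__main__":
    build_meta("<your_video_folder>")
```
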
**Preprocess videos to .meta files:**

There are two preprocessing modes. Use this guide to choose the right one:

- **Full video (`--mode video`)**
  - **What it is**: Converts each source video into a single `.meta` file that preserves the full temporal sequence as latents. Training can sample temporal windows/clips from the sequence on the fly.
  - **When to use**: Fine-tuning text-to-video models where motion and temporal consistency matter. This is the recommended default for most training runs.

- **Extract frames (`--mode frames`)**
  - **What it is**: Uniformly samples `N` frames per video and writes each as its own one-frame `.meta` sample (no temporal continuity).
  - **When to use**: Image/frame-level training objectives, quick smoke tests, or ablations where learning motion is not required.

**Mode 1: Full video (recommended for training)**
```bash
python dfm/src/automodel/utils/data/preprocess_resize.py \
  --mode video \
  --video_folder <your_video_folder> \
  --output_folder ./processed_meta \
  --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --height 480 \
  --width 720 \
  --center-crop
```

**Mode 2: Extract frames (for frame-based training)**
```bash
python dfm/src/automodel/utils/data/preprocess_resize.py \
  --mode frames \
  --num-frames 40 \
  --video_folder <your_video_folder> \
  --output_folder ./processed_frames \
  --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --height 240 \
  --width 416 \
  --center-crop
```

**Key arguments:**
- `--mode`: `video` (full video) or `frames` (extract evenly spaced frames)
- `--num-frames`: Number of frames to extract (`frames` mode only)
- `--height` / `--width`: Target resolution
- `--center-crop`: Crop to the exact target size after an aspect-preserving resize

**Preprocessing modes:**
- **`video` mode**: Processes the entire video sequence and creates one `.meta` file per video
- **`frames` mode**: Extracts N evenly spaced frames and creates one `.meta` file per frame (each treated as a 1-frame video)

**Output:** Each `.meta` file contains the items below (a quick way to inspect one is sketched after the list):
- Encoded video latents (normalized)
- Text embeddings (from UMT5)
- The first frame as a JPEG (video mode only)
- Metadata

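The snippet below is one way to peek inside a `.meta` file. It assumes the files are serialized with `torch.save` and load as a dictionary of tensors and metadata; that is an assumption made for this guide, so check `preprocess_resize.py` for the actual on-disk format.

```python
# Hypothetical sanity check: assumes .meta files are torch-serialized dictionaries.
# Only do this for files you created yourself (weights_only=False unpickles arbitrary objects).
import torch

sample = torch.load("processed_meta/video1.meta", map_location="cpu", weights_only=False)
if isinstance(sample, dict):
    for key, value in sample.items():
        shape = getattr(value, "shape", None)
        print(key, shape if shape is not None else type(value).__name__)
else:
    print(type(sample).__name__)
```
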
### 3. Train

**Single-node (8 GPUs):**
```bash
export UV_PROJECT_ENVIRONMENT=

uv run --group automodel --with . \
  torchrun --nproc-per-node=8 \
  examples/automodel/finetune/finetune.py \
  -c examples/automodel/finetune/wan2_1_t2v_flow.yaml
```

**Multi-node with SLURM:**
```bash
#!/bin/bash
#SBATCH -N 2
#SBATCH --ntasks-per-node 1
#SBATCH --gpus-per-node=8
#SBATCH --exclusive

export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=29500
export NUM_GPUS=8

# Per-rank UV cache to avoid conflicts
unset UV_PROJECT_ENVIRONMENT
mkdir -p /opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}
export UV_CACHE_DIR=/opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}

uv run --group automodel --with . \
  torchrun \
  --nnodes=$SLURM_NNODES \
  --nproc-per-node=$NUM_GPUS \
  --rdzv_backend=c10d \
  --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
  examples/automodel/finetune/finetune.py \
  -c examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml
```

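Save the script (for example as `train_multinode.sbatch`; the name is just for illustration) and submit it with `sbatch train_multinode.sbatch`. Depending on your cluster setup, you may need to prefix the `uv run ... torchrun ...` command with `srun --ntasks-per-node=1` so that it is launched on every allocated node rather than only the first one.
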
### 4. Validate

Use this step to perform a quick qualitative check of a trained checkpoint. The validation script:
- Reads prompts from `.meta` files in `--meta_folder` (uses `metadata.vila_caption`; latents are ignored).
- Loads the `WanPipeline` and, if provided, restores weights from `--checkpoint` (prefers `ema_shadow.pt`, then `consolidated_model.bin`, then sharded FSDP `model/*.distcp`).
- Generates short videos for each prompt with the specified settings (`--guidance_scale`, `--num_inference_steps`, `--height/--width`, `--num_frames`, `--fps`, `--seed`) and writes them to `--output_dir`.
- Is intended for qualitative comparison across checkpoints; it does not compute quantitative metrics.

```bash
uv run --group automodel --with . \
  python examples/automodel/generate/wan_validate.py \
  --meta_folder <your_meta_folder> \
  --guidance_scale 5 \
  --checkpoint ./checkpoints/step_1000 \
  --num_samples 10
```

**Note:** You can use `--checkpoint ./checkpoints/LATEST` to automatically use the most recent checkpoint.

---

## Configuration

### Fine-tuning Config (`wan2_1_t2v_flow.yaml`)

Note: The inline configuration below is provided for quick reference. The canonical, up-to-date files are maintained in the repository: [examples/automodel/](../../examples/automodel/), [examples/automodel/finetune/wan2_1_t2v_flow.yaml](../../examples/automodel/finetune/wan2_1_t2v_flow.yaml), and [examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml](../../examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml).

```yaml
model:                                   # Base pretrained model to fine-tune
  pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers  # HF repo or local path

step_scheduler:                          # Global training schedule
  global_batch_size: 8                   # Effective batch size across all GPUs
  local_batch_size: 1                    # Per-GPU batch size
  num_epochs: 10                         # Number of passes over the dataset
  ckpt_every_steps: 100                  # Save a checkpoint every N steps

data:                                    # Data input configuration
  dataloader:                            # DataLoader parameters
    meta_folder: "<your_processed_meta_folder>"  # Folder containing .meta files
    num_workers: 2                       # Worker processes per rank

optim:                                   # Optimizer/training hyperparameters
  learning_rate: 5e-6                    # Base learning rate

flow_matching:                           # Flow-matching training settings
  timestep_sampling: "uniform"           # Strategy for sampling timesteps
  flow_shift: 3.0                        # Scalar shift applied to the target flow

fsdp:                                    # Distributed training (e.g., FSDP) configuration
  dp_size: 8                             # Total data-parallel replicas (single node: 8 GPUs)

checkpoint:                              # Checkpointing behavior
  enabled: true                          # Enable periodic checkpoint saving
  checkpoint_dir: "./checkpoints"        # Output directory for checkpoints
```

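To make the `flow_matching` block concrete, the sketch below shows what a flow-matching training step with uniform timestep sampling and a timestep shift typically looks like. It is a schematic written for this guide, not the trainer's actual code; the model call signature and the exact `flow_shift` convention (here the common rectified-flow shift `t' = s*t / (1 + (s-1)*t)`) are assumptions.

```python
# Schematic flow-matching step (illustrative; not the repository's trainer code).
# Assumes `model(x_t, t, text_emb)` predicts the velocity field.
import torch
import torch.nn.functional as F


def flow_matching_loss(model, latents, text_emb, flow_shift=3.0):
    b = latents.shape[0]
    t = torch.rand(b, device=latents.device)              # timestep_sampling: "uniform"
    t = flow_shift * t / (1.0 + (flow_shift - 1.0) * t)   # flow_shift: 3.0 (assumed convention)
    t_b = t.view(b, *([1] * (latents.dim() - 1)))         # broadcast over latent dims
    noise = torch.randn_like(latents)
    x_t = (1.0 - t_b) * latents + t_b * noise             # interpolate data -> noise
    v_target = noise - latents                            # flow-matching velocity target
    v_pred = model(x_t, t, text_emb)
    return F.mse_loss(v_pred, v_target)
```

In this convention, `timestep_sampling: "uniform"` corresponds to the `torch.rand` draw, and a `flow_shift` greater than 1 biases the effective timesteps toward the noisier end of the trajectory.
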
### Multi-node Config Differences

```yaml
fsdp:                      # Overrides for multi-node runs
  dp_size: 16              # Total data-parallel replicas (2 nodes × 8 GPUs)
  dp_replicate_size: 2     # Number of replicated groups across nodes
```

### Pretraining vs Fine-tuning

| Setting | Fine-tuning | Pretraining |
|---------|-------------|-------------|
| `learning_rate` | 5e-6 | 5e-5 |
| `weight_decay` | 0.01 | 0.1 |
| `flow_shift` | 3.0 | 2.5 |
| `logit_std` | 1.0 | 1.5 |
| Dataset size | 100s–1000s of samples | 10K+ samples |

---

## Hardware Requirements

| Component | Minimum | Recommended |
|-----------|---------|-------------|
| GPU model | A100 40GB | A100 80GB / H100 |
| GPU count | 4 | 8+ |
| RAM | 128 GB | 256 GB+ |
| Storage | 500 GB SSD | 2 TB NVMe |

---

## Features

- ✅ **Flow Matching**: Pure flow-matching training
- ✅ **Distributed**: FSDP2 + Tensor Parallelism
- ✅ **Mixed Precision**: BF16 by default
- ✅ **WandB**: Automatic logging
- ✅ **Checkpointing**: Consolidated and sharded formats
- ✅ **Multi-node**: SLURM and torchrun support

---

## Supported Models

| Model | Parameters | Parallelization | Status |
|-------|------------|-----------------|--------|
| Wan 2.1 T2V 1.3B | 1.3B | FSDP2 via Automodel + DDP | ✅ |
| Wan 2.1 T2V 14B | 14B | FSDP2 via Automodel + DDP | ✅ |
| FLUX | TBD | TBD | 🔄 In progress |

---

## Advanced

**Custom parallelization:**
```yaml
fsdp:
  tp_size: 2   # Tensor parallel
  dp_size: 4   # Data parallel
```
With `tp_size: 2` and `dp_size: 4`, each model replica is sharded across 2 GPUs for tensor parallelism and there are 4 data-parallel replicas, so this layout targets 8 GPUs in total.

**Checkpoint cleanup:**
```python
import shutil
from pathlib import Path


def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3):
    # Sort numerically by step so that e.g. step_1000 comes after step_200.
    checkpoints = sorted(
        Path(checkpoint_dir).glob("step_*"),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    for old_ckpt in checkpoints[:-keep_last_n]:
        shutil.rmtree(old_ckpt)
```
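For example, calling `cleanup_old_checkpoints("./checkpoints", keep_last_n=3)` after a run keeps only the three most recent `step_*` directories under the configured `checkpoint_dir`.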
