# Diffusion Model Fine-tuning with Automodel Backend

Train diffusion models with distributed training support using NeMo Automodel and flow matching.

**Currently Supported:** Wan 2.1 Text-to-Video (1.3B and 14B models)

---

## Quick Start

### 1. Docker Setup

```bash
# Build image
docker build -f docker/Dockerfile.ci -t dfm-training .

# Run container
docker run --gpus all -it \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  --ipc=host \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  dfm-training bash

# Inside container: Initialize submodules
export UV_PROJECT_ENVIRONMENT=
git submodule update --init --recursive 3rdparty/
```

### 2. Prepare Data

We provide two ways to prepare your dataset:

- Start with raw videos: Place your `.mp4` files in a folder and use our data-preparation scripts to scan the videos and generate a `meta.json` entry for each sample (including `width`, `height`, `start_frame`, `end_frame`, and a caption). If you already have captions, you can also provide per-video caption files named `<video>.jsonl`; the scripts will pick up the text automatically. The final dataset layout is shown below.
- Bring your own `meta.json`: If you already have annotations, create `meta.json` yourself following the schema shown below.

**Create video dataset:**
The following example uses just two video files for demonstration purposes; real training datasets will contain a much larger number of files.
```
<your_video_folder>/
├── video1.mp4
├── video2.mp4
└── meta.json
```

**meta.json format:**
```json
[
  {
    "file_name": "video1.mp4",
    "width": 1280,
    "height": 720,
    "start_frame": 0,
    "end_frame": 121,
    "vila_caption": "A detailed description of the video1.mp4 contents..."
  },
  {
    "file_name": "video2.mp4",
    "width": 1280,
    "height": 720,
    "start_frame": 0,
    "end_frame": 12,
    "vila_caption": "A detailed description of the video2.mp4 contents..."
  }
]
```
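
If you are assembling `meta.json` yourself, a minimal sketch along the lines below can probe each video with OpenCV and emit entries in the schema above. It is illustrative only: `caption_lookup` is a hypothetical mapping from file name to caption (for example, parsed from your per-video caption files), and whether `end_frame` should be the total frame count or the last frame index is an assumption to verify against your data.

```python
import json
from pathlib import Path

import cv2  # pip install opencv-python


def build_meta(video_folder: str, caption_lookup: dict) -> None:
    """Write meta.json next to the .mp4 files in video_folder."""
    entries = []
    for video_path in sorted(Path(video_folder).glob("*.mp4")):
        cap = cv2.VideoCapture(str(video_path))
        entries.append({
            "file_name": video_path.name,
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "start_frame": 0,
            "end_frame": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            "vila_caption": caption_lookup.get(video_path.name, ""),
        })
        cap.release()
    (Path(video_folder) / "meta.json").write_text(json.dumps(entries, indent=2))
```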

**Preprocess videos to .meta files:**

There are two preprocessing modes; use the guide below to choose between them:

- **Full video (`--mode video`)**
  - **What it is**: Converts each source video into a single `.meta` file that preserves the full temporal sequence as latents. Training can sample temporal windows/clips from the sequence on the fly.
  - **When to use**: Fine-tuning text-to-video models where motion and temporal consistency matter. This is the recommended default for most training runs.

- **Extract frames (`--mode frames`)**
  - **What it is**: Uniformly samples `N` frames per video and writes each as its own one-frame `.meta` sample (no temporal continuity).
  - **When to use**: Image/frame-level training objectives, quick smoke tests, or ablations where learning motion is not required.

**Mode 1: Full video (recommended for training)**
```bash
python dfm/src/automodel/utils/data/preprocess_resize.py \
  --mode video \
  --video_folder <your_video_folder> \
  --output_folder ./processed_meta \
  --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --height 480 \
  --width 720 \
  --center-crop
```

**Mode 2: Extract frames (for frame-based training)**
```bash
python dfm/src/automodel/utils/data/preprocess_resize.py \
  --mode frames \
  --num-frames 40 \
  --video_folder <your_video_folder> \
  --output_folder ./processed_frames \
  --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --height 240 \
  --width 416 \
  --center-crop
```

**Key arguments:**
- `--mode`: `video` (full video) or `frames` (extract evenly-spaced frames)
- `--num-frames`: Number of frames to extract (only for `frames` mode)
- `--height/--width`: Target resolution
- `--center-crop`: Crop to exact size after aspect-preserving resize
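
For reference, the resize-then-center-crop behavior described above typically amounts to the following (an illustrative sketch with Pillow, not the preprocessing script's actual code):

```python
from PIL import Image


def resize_and_center_crop(img: Image.Image, target_w: int, target_h: int) -> Image.Image:
    w, h = img.size
    # Scale so the resized image fully covers the target while preserving aspect ratio.
    scale = max(target_w / w, target_h / h)
    new_w, new_h = round(w * scale), round(h * scale)
    img = img.resize((new_w, new_h), Image.Resampling.BICUBIC)
    # Crop the centered target_w x target_h window.
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return img.crop((left, top, left + target_w, top + target_h))
```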

**Preprocessing modes:**
- **`video` mode**: Processes entire video sequence, creates one `.meta` file per video
- **`frames` mode**: Extracts N evenly-spaced frames, creates one `.meta` file per frame (treated as 1-frame videos)

**Output:** Creates `.meta` files containing:
- Encoded video latents (normalized)
- Text embeddings (from UMT5)
- First frame as JPEG (video mode only)
- Metadata
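
To sanity-check a preprocessed sample, you can try loading one directly. This sketch assumes a `.meta` file deserializes with `torch.load` into a dict; the exact serialization and key names are an assumption, so treat it as a starting point and adjust to whatever the preprocessing script actually writes.

```python
import torch

# weights_only=False because the file may contain arbitrary pickled metadata.
sample = torch.load("processed_meta/<some_sample>.meta", map_location="cpu", weights_only=False)

print(type(sample))
if isinstance(sample, dict):
    for key, value in sample.items():
        # Print tensor shapes where available, otherwise the value's type.
        print(key, getattr(value, "shape", type(value)))
```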

### 3. Train

**Single-node (8 GPUs):**
```bash
export UV_PROJECT_ENVIRONMENT=

uv run --group automodel --with . \
  torchrun --nproc-per-node=8 \
  examples/automodel/finetune/finetune.py \
  -c examples/automodel/finetune/wan2_1_t2v_flow.yaml
```

**Multi-node with SLURM:**
```bash
#!/bin/bash
#SBATCH -N 2
#SBATCH --ntasks-per-node 1
#SBATCH --gpus-per-node=8
#SBATCH --exclusive

export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=29500
export NUM_GPUS=8

# Per-rank UV cache to avoid conflicts
unset UV_PROJECT_ENVIRONMENT
mkdir -p /opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}
export UV_CACHE_DIR=/opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}

uv run --group automodel --with . \
  torchrun \
  --nnodes=$SLURM_NNODES \
  --nproc-per-node=$NUM_GPUS \
  --rdzv_backend=c10d \
  --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
  examples/automodel/finetune/finetune.py \
  -c examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml
```

### 4. Validate

Use this step to perform a quick qualitative check of a trained checkpoint. The validation script:
- Reads prompts from `.meta` files in `--meta_folder` (uses `metadata.vila_caption`; latents are ignored).
- Loads the `WanPipeline` and, if provided, restores weights from `--checkpoint` (prefers `ema_shadow.pt`, then `consolidated_model.bin`, then sharded FSDP `model/*.distcp`); see the checkpoint-resolution sketch below.
- Generates short videos for each prompt with the specified settings (`--guidance_scale`, `--num_inference_steps`, `--height/--width`, `--num_frames`, `--fps`, `--seed`) and writes them to `--output_dir`.
- Intended for qualitative comparison across checkpoints; it does not compute quantitative metrics.

```bash
uv run --group automodel --with . \
  python examples/automodel/generate/wan_validate.py \
  --meta_folder <your_meta_folder> \
  --guidance_scale 5 \
  --checkpoint ./checkpoints/step_1000 \
  --num_samples 10
```

**Note:** You can use `--checkpoint ./checkpoints/LATEST` to automatically use the most recent checkpoint.
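
The checkpoint preference order described above can be mirrored with a small helper if you want to point other tooling at the same files. This is an illustrative sketch, not the validation script's own loading code:

```python
from pathlib import Path


def resolve_checkpoint(checkpoint_dir: str) -> Path:
    ckpt = Path(checkpoint_dir)
    if (ckpt / "ema_shadow.pt").exists():
        return ckpt / "ema_shadow.pt"           # EMA weights (preferred)
    if (ckpt / "consolidated_model.bin").exists():
        return ckpt / "consolidated_model.bin"  # single consolidated file
    if list((ckpt / "model").glob("*.distcp")):
        return ckpt / "model"                   # sharded FSDP checkpoint directory
    raise FileNotFoundError(f"No recognizable checkpoint in {checkpoint_dir}")
```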

---

## Configuration

### Fine-tuning Config (`wan2_1_t2v_flow.yaml`)

Note: The inline configuration below is provided for quick reference. The canonical, up-to-date files are maintained in the repository: [examples/automodel/](../../examples/automodel/), [examples/automodel/finetune/wan2_1_t2v_flow.yaml](../../examples/automodel/finetune/wan2_1_t2v_flow.yaml), and [examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml](../../examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml).

```yaml
model:  # Base pretrained model to fine-tune
  pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers  # HF repo or local path

step_scheduler:  # Global training schedule
  global_batch_size: 8  # Effective batch size across all GPUs
  local_batch_size: 1  # Per-GPU batch size
  num_epochs: 10  # Number of passes over the dataset
  ckpt_every_steps: 100  # Save a checkpoint every N steps

data:  # Data input configuration
  dataloader:  # DataLoader parameters
    meta_folder: "<your_processed_meta_folder>"  # Folder containing .meta files
    num_workers: 2  # Worker processes per rank

optim:  # Optimizer/training hyperparameters
  learning_rate: 5e-6  # Base learning rate

flow_matching:  # Flow-matching training settings
  timestep_sampling: "uniform"  # Strategy for sampling timesteps
  flow_shift: 3.0  # Scalar shift applied to the target flow

fsdp:  # Distributed training (e.g., FSDP) configuration
  dp_size: 8  # Total data-parallel replicas (single node: 8 GPUs)

checkpoint:  # Checkpointing behavior
  enabled: true  # Enable periodic checkpoint saving
  checkpoint_dir: "./checkpoints"  # Output directory for checkpoints
```
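
The `step_scheduler` values are related through gradient accumulation. Assuming the common convention `global_batch_size = local_batch_size × data-parallel ranks × accumulation steps` (verify against the trainer if in doubt), a quick sanity check looks like this:

```python
global_batch_size = 8
local_batch_size = 1
dp_size = 8  # data-parallel ranks (single node, 8 GPUs)

grad_accum_steps, remainder = divmod(global_batch_size, local_batch_size * dp_size)
assert remainder == 0, "global_batch_size must divide evenly across ranks"
print(grad_accum_steps)  # -> 1: each optimizer step uses exactly one micro-batch per GPU
```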

### Multi-node Config Differences

```yaml
fsdp:  # Overrides for multi-node runs
  dp_size: 16  # Total data-parallel replicas (2 nodes × 8 GPUs)
  dp_replicate_size: 2  # Number of replicated groups across nodes
```
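
A common way to read these values (an assumption about the convention, not a statement about Automodel internals) is that the `dp_size` data-parallel ranks are split into `dp_replicate_size` replica groups, each sharding over `dp_size // dp_replicate_size` GPUs:

```python
dp_size = 16           # 2 nodes x 8 GPUs
dp_replicate_size = 2  # e.g. one replica group per node
dp_shard_size = dp_size // dp_replicate_size
print(dp_shard_size)   # -> 8, i.e. sharding stays within a single node
```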
| 225 | +
|
| 226 | +### Pretraining vs Fine-tuning |
| 227 | +
|
| 228 | +| Setting | Fine-tuning | Pretraining | |
| 229 | +|---------|-------------|-------------| |
| 230 | +| `learning_rate` | 5e-6 | 5e-5 | |
| 231 | +| `weight_decay` | 0.01 | 0.1 | |
| 232 | +| `flow_shift` | 3.0 | 2.5 | |
| 233 | +| `logit_std` | 1.0 | 1.5 | |
| 234 | +| Dataset size | 100s-1000s | 10K+ | |
| 235 | + |
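
For context on the `flow_shift` and `logit_std` rows, here is a sketch of how flow-matching trainers commonly draw timesteps: sample `t` either uniformly or from a logit-normal distribution, then apply the shift `t' = s·t / (1 + (s − 1)·t)`. The option name `logit_normal` and the exact formula are assumptions about common practice, not a transcription of this repo's code.

```python
import torch


def sample_timesteps(batch_size, sampling="uniform", logit_std=1.0, flow_shift=3.0):
    if sampling == "uniform":
        t = torch.rand(batch_size)  # t ~ U(0, 1)
    else:  # e.g. "logit_normal"
        t = torch.sigmoid(torch.randn(batch_size) * logit_std)
    # flow_shift=1.0 is a no-op; larger values push t toward 1.
    return flow_shift * t / (1.0 + (flow_shift - 1.0) * t)


t = sample_timesteps(4, sampling="logit_normal", logit_std=1.0, flow_shift=3.0)
```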
---

## Hardware Requirements

| Component | Minimum | Recommended |
|-----------|---------|-------------|
| GPU model | A100 40GB | A100 80GB / H100 |
| GPU count | 4 | 8+ |
| RAM | 128 GB | 256 GB+ |
| Storage | 500 GB SSD | 2 TB NVMe |

---

## Features

- ✅ **Flow Matching**: Pure flow matching training
- ✅ **Distributed**: FSDP2 + Tensor Parallelism
- ✅ **Mixed Precision**: BF16 by default
- ✅ **WandB**: Automatic logging
- ✅ **Checkpointing**: consolidated and sharded formats
- ✅ **Multi-node**: SLURM and torchrun support

---

## Supported Models

| Model | Parameters | Parallelization | Status |
|-------|------------|-----------------|--------|
| Wan 2.1 T2V 1.3B | 1.3B | FSDP2 via Automodel + DDP | ✅ |
| Wan 2.1 T2V 14B | 14B | FSDP2 via Automodel + DDP | ✅ |
| FLUX | TBD | TBD | 🔄 In Progress |

---

## Advanced

**Custom parallelization:**
```yaml
fsdp:
  tp_size: 2  # Tensor parallel
  dp_size: 4  # Data parallel
```
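
To see what such a layout means on 8 GPUs, here is an illustrative PyTorch device-mesh sketch (4-way data parallel × 2-way tensor parallel). It is not Automodel's internal code; run it under `torchrun --nproc-per-node=8` so the distributed environment is set up:

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks arranged as a 4 x 2 grid: dim "dp" replicates/shards data, dim "tp" splits tensors.
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
print(mesh["dp"].size(), mesh["tp"].size())  # -> 4 2
```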

**Checkpoint cleanup:**
```python
from pathlib import Path
import shutil

def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3):
    # Sort numerically by step so that e.g. step_900 comes before step_1000
    # (plain lexicographic sorting would order step_1000 before step_200).
    checkpoints = sorted(
        Path(checkpoint_dir).glob("step_*"),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    # Delete everything except the most recent keep_last_n checkpoints.
    for old_ckpt in checkpoints[:-keep_last_n]:
        shutil.rmtree(old_ckpt)
```