|
| 1 | +--- |
| 2 | +description: "Comprehensive guide for fine-tuning pretrained models with automatic parallelism and advanced configuration options" |
| 3 | +categories: ["tutorials", "automodel"] |
| 4 | +tags: ["training", "recipe", "automodel", "wan", "advanced"] |
| 5 | +personas: ["mle-focused", "data-scientist-focused"] |
| 6 | +difficulty: "intermediate" |
| 7 | +content_type: "tutorial" |
| 8 | +--- |
| 9 | + |
| 10 | +(tutorial-fine-tuning-pretrained-models)= |
| 11 | + |
| 12 | +# Fine-Tuning Pretrained Models |
| 13 | + |
| 14 | +Comprehensive guide for fine-tuning pretrained diffusion models with automatic parallelism and distributed training support. This approach uses NeMo Automodel backend, which handles parallelism automatically—ideal for quick prototyping and fine-tuning workflows. |
| 15 | + |
| 16 | +**Currently Supported:** WAN 2.1 Text-to-Video (1.3B and 14B models) |
| 17 | + |
| 18 | +:::{note} |
| 19 | +For a quick start guide, see [Automodel Workflow](../get-started/automodel.md). This tutorial provides detailed configuration options and advanced topics. |
| 20 | +::: |
| 21 | + |
| 22 | +--- |
| 23 | + |
| 24 | +## Quick Start |
| 25 | + |
| 26 | +### 1. Docker Setup |
| 27 | + |
| 28 | +```bash |
| 29 | +# Build image |
| 30 | +docker build -f docker/Dockerfile.ci -t dfm-training . |
| 31 | + |
| 32 | +# Run container |
| 33 | +docker run --gpus all -it \ |
| 34 | + -v $(pwd):/workspace \ |
| 35 | + -v /path/to/data:/data \ |
| 36 | + --ipc=host \ |
| 37 | + --ulimit memlock=-1 \ |
| 38 | + --ulimit stack=67108864 \ |
| 39 | + dfm-training bash |
| 40 | + |
| 41 | +# Inside container: Initialize submodules |
| 42 | +export UV_PROJECT_ENVIRONMENT= |
| 43 | +git submodule update --init --recursive 3rdparty/ |
| 44 | +``` |
| 45 | + |
| 46 | +### 2. Prepare Data |
| 47 | + |
| 48 | +We provide two ways to prepare your dataset: |
| 49 | + |
| 50 | +- Start with raw videos: Place your `.mp4` files in a folder and use our data-preparation scripts to scan the videos and generate a `meta.json` entry for each sample (which includes `width`, `height`, `start_frame`, `end_frame`, and a caption). If you have captions, you can also include per-video named `<video>.jsonl`; the scripts will pick up the text automatically. The final dataset layout is shown below. |
| 51 | +- Bring your own `meta.json`: If you already have annotations, create `meta.json` yourself following the schema shown below. |
| 52 | + |
| 53 | +**Create video dataset:** |
| 54 | +In the following example we use two video files, solely for demonstration purposes. Actual training datasets will have a large number of files. |
| 55 | +``` |
| 56 | +<your_video_folder>/ |
| 57 | +├── video1.mp4 |
| 58 | +├── video2.mp4 |
| 59 | +└── meta.json |
| 60 | +``` |
| 61 | + |
| 62 | +**meta.json format:** |
| 63 | +```json |
| 64 | +[ |
| 65 | + { |
| 66 | + "file_name": "video1.mp4", |
| 67 | + "width": 1280, |
| 68 | + "height": 720, |
| 69 | + "start_frame": 0, |
| 70 | + "end_frame": 121, |
| 71 | + "vila_caption": "A detailed description of the video1.mp4 contents..." |
| 72 | + }, |
| 73 | + { |
| 74 | + "file_name": "video2.mp4", |
| 75 | + "width": 1280, |
| 76 | + "height": 720, |
| 77 | + "start_frame": 0, |
| 78 | + "end_frame": 12, |
| 79 | + "vila_caption": "A detailed description of the video2.mp4 contents..." |
| 80 | + } |
| 81 | +] |
| 82 | +``` |
| 83 | + |
| 84 | +**Preprocess videos to .meta files:** |
| 85 | + |
| 86 | +The preprocessing script converts each source video into a single `.meta` file that preserves the full temporal sequence as latents. Training can sample temporal windows/clips from the sequence on the fly. |
| 87 | + |
| 88 | +```bash |
| 89 | +python dfm/src/automodel/utils/data/preprocess_resize.py \ |
| 90 | + --video_folder <your_video_folder> \ |
| 91 | + --output_folder ./processed_meta \ |
| 92 | + --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ |
| 93 | + --height 480 \ |
| 94 | + --width 720 \ |
| 95 | + --center-crop |
| 96 | +``` |
| 97 | + |
| 98 | +**Key arguments:** |
| 99 | +- `--video_folder`: Path to folder containing videos and `meta.json` |
| 100 | +- `--output_folder`: Path where `.meta` files will be saved |
| 101 | +- `--model`: Wan2.1 model ID (default: `Wan-AI/Wan2.1-T2V-14B-Diffusers`) |
| 102 | +- `--height/--width`: Target resolution (both must be specified together) |
| 103 | +- `--center-crop`: Crop to exact size after aspect-preserving resize |
| 104 | +- `--device`: Device to use (`cuda` or `cpu`, default: `cuda` if available) |
| 105 | +- `--stochastic`: Use stochastic encoding instead of deterministic (may cause flares) |
| 106 | +- `--no-memory-optimization`: Disable Wan's built-in memory optimization |
| 107 | + |
| 108 | +**Output:** Creates `.meta` files containing: |
| 109 | +- Encoded video latents (normalized) |
| 110 | +- Text embeddings (from UMT5) |
| 111 | +- First frame as JPEG |
| 112 | +- Metadata |
| 113 | + |
| 114 | +### 3. Train |
| 115 | + |
| 116 | +**Single-node (8 GPUs):** |
| 117 | +```bash |
| 118 | +export UV_PROJECT_ENVIRONMENT= |
| 119 | + |
| 120 | +uv run --group automodel --with . \ |
| 121 | + torchrun --nproc-per-node=8 \ |
| 122 | + examples/automodel/finetune/finetune.py \ |
| 123 | + -c examples/automodel/finetune/wan2_1_t2v_flow.yaml |
| 124 | +``` |
| 125 | + |
| 126 | +**Multi-node with SLURM:** |
| 127 | +```bash |
| 128 | +#!/bin/bash |
| 129 | +#SBATCH -N 2 |
| 130 | +#SBATCH --ntasks-per-node 1 |
| 131 | +#SBATCH --gpus-per-node=8 |
| 132 | +#SBATCH --exclusive |
| 133 | + |
| 134 | +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) |
| 135 | +export MASTER_PORT=29500 |
| 136 | +export NUM_GPUS=8 |
| 137 | + |
| 138 | +# Per-rank UV cache to avoid conflicts |
| 139 | +unset UV_PROJECT_ENVIRONMENT |
| 140 | +mkdir -p /opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID} |
| 141 | +export UV_CACHE_DIR=/opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID} |
| 142 | + |
| 143 | +uv run --group automodel --with . \ |
| 144 | + torchrun \ |
| 145 | + --nnodes=$SLURM_NNODES \ |
| 146 | + --nproc-per-node=$NUM_GPUS \ |
| 147 | + --rdzv_backend=c10d \ |
| 148 | + --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ |
| 149 | + examples/automodel/finetune/finetune.py \ |
| 150 | + -c examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml |
| 151 | +``` |
| 152 | + |
| 153 | +### 4. Validate |
| 154 | + |
| 155 | +Use this step to perform a quick qualitative check of a trained checkpoint. The validation script: |
| 156 | +- Reads prompts from `.meta` files in `--meta_folder` (uses `metadata.vila_caption`; latents are ignored). |
| 157 | +- Loads the `WanPipeline` and, if provided, restores weights from `--checkpoint` (prefers `ema_shadow.pt`, then `consolidated_model.bin`, then sharded FSDP `model/*.distcp`). |
| 158 | +- Generates short videos for each prompt with the specified settings (`--guidance_scale`, `--num_inference_steps`, `--height/--width`, `--num_frames`, `--fps`, `--seed`) and writes them to `--output_dir`. |
| 159 | +- Intended for qualitative comparison across checkpoints; it does not compute quantitative metrics. |
| 160 | + |
| 161 | +```bash |
| 162 | +uv run --group automodel --with . \ |
| 163 | + python examples/automodel/generate/wan_validate.py \ |
| 164 | + --meta_folder <your_meta_folder> \ |
| 165 | + --guidance_scale 5 \ |
| 166 | + --checkpoint ./checkpoints/step_1000 \ |
| 167 | + --num_samples 10 |
| 168 | +``` |
| 169 | + |
| 170 | +**Note:** You can use `--checkpoint ./checkpoints/LATEST` to automatically use the most recent checkpoint. |
| 171 | + |
| 172 | +--- |
| 173 | + |
| 174 | +## Configuration |
| 175 | + |
| 176 | +### Fine-tuning Config (`wan2_1_t2v_flow.yaml`) |
| 177 | + |
| 178 | +Note: The inline configuration below is provided for quick reference. The canonical, up-to-date files are maintained in the repository: [examples/automodel/](../../examples/automodel/), [examples/automodel/finetune/wan2_1_t2v_flow.yaml](../../examples/automodel/finetune/wan2_1_t2v_flow.yaml), and [examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml](../../examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml). |
| 179 | + |
| 180 | +```yaml |
| 181 | +model: # Base pretrained model to fine-tune |
| 182 | + pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers # HF repo or local path |
| 183 | + |
| 184 | +step_scheduler: # Global training schedule |
| 185 | + global_batch_size: 8 # Effective batch size across all GPUs |
| 186 | + local_batch_size: 1 # Per-GPU batch size |
| 187 | + num_epochs: 10 # Number of passes over the dataset |
| 188 | + ckpt_every_steps: 100 # Save a checkpoint every N steps |
| 189 | + |
| 190 | +data: # Data input configuration |
| 191 | + dataloader: # DataLoader parameters |
| 192 | + meta_folder: "<your_processed_meta_folder>" # Folder containing .meta files |
| 193 | + num_workers: 2 # Worker processes per rank |
| 194 | + |
| 195 | +optim: # Optimizer/training hyperparameters |
| 196 | + learning_rate: 5e-6 # Base learning rate |
| 197 | + |
| 198 | +flow_matching: # Flow-matching training settings |
| 199 | + timestep_sampling: "uniform" # Strategy for sampling timesteps |
| 200 | + flow_shift: 3.0 # Scalar shift applied to the target flow |
| 201 | + |
| 202 | +fsdp: # Distributed training (e.g., FSDP) configuration |
| 203 | + dp_size: 8 # Total data-parallel replicas (single node: 8 GPUs) |
| 204 | + |
| 205 | +checkpoint: # Checkpointing behavior |
| 206 | + enabled: true # Enable periodic checkpoint saving |
| 207 | + checkpoint_dir: "./checkpoints" # Output directory for checkpoints |
| 208 | +``` |
| 209 | +
|
| 210 | +### Multi-node Config Differences |
| 211 | +
|
| 212 | +```yaml |
| 213 | +fsdp: # Overrides for multi-node runs |
| 214 | + dp_size: 16 # Total data-parallel replicas (2 nodes × 8 GPUs) |
| 215 | + dp_replicate_size: 2 # Number of replicated groups across nodes |
| 216 | +``` |
| 217 | +
|
| 218 | +### Pretraining vs Fine-tuning |
| 219 | +
|
| 220 | +| Setting | Fine-tuning | Pretraining | |
| 221 | +|---------|-------------|-------------| |
| 222 | +| `learning_rate` | 5e-6 | 5e-5 | |
| 223 | +| `weight_decay` | 0.01 | 0.1 | |
| 224 | +| `flow_shift` | 3.0 | 2.5 | |
| 225 | +| `logit_std` | 1.0 | 1.5 | |
| 226 | +| Dataset size | 100s-1000s | 10K+ | |
| 227 | + |
| 228 | +--- |
| 229 | + |
| 230 | +## Hardware Requirements |
| 231 | + |
| 232 | +| Component | Minimum | Recommended | |
| 233 | +|-----------|---------|-------------| |
| 234 | +| GPU | A100 40GB | A100 80GB / H100 | |
| 235 | +| GPUs | 4 | 8+ | |
| 236 | +| RAM | 128 GB | 256 GB+ | |
| 237 | +| Storage | 500 GB SSD | 2 TB NVMe | |
| 238 | + |
| 239 | +--- |
| 240 | + |
| 241 | +## Features |
| 242 | + |
| 243 | +- ✅ **Flow Matching**: Pure flow matching training |
| 244 | +- ✅ **Distributed**: FSDP2 + Tensor Parallelism |
| 245 | +- ✅ **Mixed Precision**: BF16 by default |
| 246 | +- ✅ **WandB**: Automatic logging |
| 247 | +- ✅ **Checkpointing**: consolidated, and sharded formats |
| 248 | +- ✅ **Multi-node**: SLURM and torchrun support |
| 249 | + |
| 250 | +--- |
| 251 | + |
| 252 | +## Supported Models |
| 253 | + |
| 254 | +| Model | Parameters | Parallelization | Status | |
| 255 | +|-------|------------|-----------------|--------| |
| 256 | +| Wan 2.1 T2V 1.3B | 1.3B | FSDP2 via Automodel + DDP | ✅ | |
| 257 | +| Wan 2.1 T2V 14B | 14B | FSDP2 via Automodel + DDP | ✅ | |
| 258 | +| FLUX | TBD | TBD | 🔄 In Progress | |
| 259 | + |
| 260 | +--- |
| 261 | + |
| 262 | +## Advanced |
| 263 | + |
| 264 | +**Custom parallelization:** |
| 265 | +```yaml |
| 266 | +fsdp: |
| 267 | + tp_size: 2 # Tensor parallel |
| 268 | + dp_size: 4 # Data parallel |
| 269 | +``` |
| 270 | + |
| 271 | +**Checkpoint cleanup:** |
| 272 | +```python |
| 273 | +from pathlib import Path |
| 274 | +import shutil |
| 275 | +
|
| 276 | +def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3): |
| 277 | + checkpoints = sorted(Path(checkpoint_dir).glob("step_*")) |
| 278 | + for old_ckpt in checkpoints[:-keep_last_n]: |
| 279 | + shutil.rmtree(old_ckpt) |
| 280 | +``` |
| 281 | + |
| 282 | +--- |
| 283 | + |
| 284 | +## Related Documentation |
| 285 | + |
| 286 | +- [Automodel Quick Start](../get-started/automodel.md) - Quick start guide |
| 287 | +- [Training Paradigms](../about/concepts/training-paradigms.md) - Understanding Automodel vs Megatron approaches |
| 288 | +- [Performance Benchmarks](../reference/performance.md) - Training throughput metrics |
| 289 | + |
0 commit comments