## 🚀 Megatron WAN

### 📋 Overview
An open-source implementation of [WAN 2.1](https://github.com/Wan-Video/Wan2.1) (large-scale text-to-video/image generative models) built on top of [Megatron-Core](https://github.com/NVIDIA/Megatron-LM) and [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) for scalable and efficient training. It supports advanced parallelism strategies (data, tensor, sequence, and context parallelism) and optimized kernels (e.g., Transformer Engine fused attention).

---

### 📦 Dataset Preparation
This recipe uses NVIDIA's [Megatron-Energon](https://github.com/NVIDIA/Megatron-Energon) as an efficient multi-modal data loader. Datasets should be in the WebDataset-compatible format (typically sharded `.tar` archives). Energon supports large-scale distributed loading, sharding, and sampling for video-text and image-text pairs. Set `dataset.path` to your WebDataset directory or shard pattern. See the Megatron-Energon docs for format details, subflavors, and advanced options.
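
For orientation, here is a hypothetical listing of an already-prepared dataset directory (the shard names below are purely illustrative; any WebDataset-compatible sharding works):

```bash
# Illustrative only: inspect a prepared WebDataset directory (names are hypothetical).
ls -a /opt/wan_webdataset
# .nv-meta/  shard-000000.tar  shard-000001.tar  shard-000002.tar  ...
# Point the training config at this directory, e.g. dataset.path=/opt/wan_webdataset
```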

If you do not have a dataset yet or only need to validate performance/plumbing, see the "Quick Start with Mock Dataset" section below.

---

#### 🗂️ Dataset Preparation Example
Starting from a directory of raw .mp4 videos and their corresponding .json metadata files containing captions, you can turn the data into WAN-ready WebDataset shards using our helper script. Energon is then used to process those shards and create its metadata. After this, set the training script's `dataset.path` argument to the processed output folder and start training.

```bash
# 1) Define your input (raw videos) and output (WebDataset shards) folders. For example:
DATASET_SRC=/opt/raw_videos       # contains .mp4 and per-video .jsonl captions
DATASET_PATH=/opt/wan_webdataset  # output WebDataset shards

# 2) (Optional) If your WAN models require auth on first download
export HF_TOKEN=<your_huggingface_token>

# 3) Create WAN shards with latents + text embeddings
# Wan's VAE encoder and T5 encoder are used to extract video latents and caption embeddings offline before training, using the following core arguments:
# --height/--width: control the resize target (832x480 is supported for both the 1.3B and 14B models)
# --center-crop: center-crop to the exact target size after resizing
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node 1 \
  examples/megatron/recipes/wan/prepare_energon_dataset_wan.py \
  --video_folder "${DATASET_SRC}" \
  --output_dir "${DATASET_PATH}" \
  --model "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" \
  --height 480 --width 832 \
  --center-crop

# 4) Use Energon to process shards and create its metadata/spec
energon prepare "${DATASET_PATH}"
# In the interactive prompts:
# - Enter a train/val/test split, e.g., "8,1,1"
# - When asked for the sample type, choose: "Crude sample (plain dict for cooking)"
```

What gets produced:
- Each shard contains (you can inspect one as sketched below):
  - `.pth`: WAN video latents
  - `.pickle`: text embeddings
  - `.json`: useful side-info (text caption, sizes, processing choices, etc.)
- Energon writes a `.nv-meta` directory with dataset info and a `dataset.yaml` you can version control.
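
As a quick sanity check, you can list the members of a produced shard with standard `tar` (the shard filename below is hypothetical; substitute one that actually exists in your output directory):

```bash
# List the first members of one produced shard (filename is illustrative).
tar -tf "${DATASET_PATH}"/shard-000000.tar | head -n 20
# Expect per-sample triplets such as <key>.pth, <key>.pickle, <key>.json
```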

You're ready to launch training. In the training config (or via CLI overrides), point `dataset.path` to the processed output directory, i.e. `dataset.path=${DATASET_PATH}`.

---

### 🐳 Build Container

Please follow the instructions in the container section of the main README:

- DFM container guide: https://github.com/NVIDIA-NeMo/DFM#-built-your-own-container

---

### 🏋️ Pretraining

This recipe leverages sequence packing to maximize throughput. When a batch contains videos with different shapes or resolutions, naive batching with padding requires a significant number of padded tokens due to the inherently large size of videos. Sequence packing instead stacks multiple samples (with different resolutions) into a single sequence, so no computation is wasted on padded tokens. When using sequence packing (see the override sketch after this list):
- Set `train.micro_batch_size=1` and `dataset.micro_batch_size=1`
- Ensure `model.qkv_format=thd` (required with context parallelism and recommended with sequence packing)
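
For concreteness, here is a sketch of those packing-related settings expressed as CLI overrides on top of the provided example config (this mirrors the launch commands shown later in this section; only the packing-related overrides are new):

```bash
# Sketch: packing-related overrides appended to a standard launch command.
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
  examples/megatron/recipes/wan/pretrain_wan.py \
  --training-mode pretrain \
  --config-file examples/megatron/recipes/wan/conf/wan_1_3B.yaml \
  train.micro_batch_size=1 \
  dataset.micro_batch_size=1 \
  model.qkv_format=thd
```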

Multiple parallelism techniques, including tensor, sequence, and context parallelism, are supported and configurable for your hardware.

Wan training is driven by `examples/megatron/recipes/wan/pretrain_wan.py`, which supports both a YAML config file and CLI overrides.

The script exposes a `--training-mode` flag with `pretrain` and `finetune` presets for flow-matching hyperparameters as a starting point for experiments. These presets specify that pretraining uses noisier, biased sampling (e.g., logit-normal sampling, higher `logit_std`, lower `flow_shift`) for stability and broad learning, while finetuning uses uniform, lower-noise settings (e.g., uniform sampling, lower `logit_std`, higher `flow_shift`) to refine details and improve quality.

**Note**: If you use `logger.wandb_project` and `logger.wandb_exp_name`, export `WANDB_API_KEY`.
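
For example, a minimal Weights & Biases setup might look like the following (the project and experiment names are placeholders, not prescribed values):

```bash
# Required for W&B logging when logger.wandb_* options are set.
export WANDB_API_KEY=<your_wandb_api_key>
# Then append logger overrides to your launch command, e.g.:
#   logger.wandb_project=my_wan_project logger.wandb_exp_name=wan_1_3B_pretrain
```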

#### Pretraining script example

We provide example configs for running the 1.3B and 14B model sizes on a mock dataset (see `wan_1_3B.yaml` and `wan_14B.yaml` under `examples/megatron/recipes/wan/conf`). From these starting points, you can create your own configuration by copying one of the example override configs and updating it with your settings (e.g., the actual processed data path and parallelism settings based on the available hardware). You can find more detail on the arguments in the [Megatron-Bridge docs](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docs/megatron-lm-to-megatron-bridge.md).

```bash
cp examples/megatron/recipes/wan/conf/wan_1_3B.yaml examples/megatron/recipes/wan/conf/my_wan.yaml
# Edit my_wan.yaml to set:
# - dataset.path: Path to your WebDataset directory
# - train.global_batch_size/micro_batch_size: Keep micro_batch_size=1
# - model.tensor_model_parallel_size / model.context_parallel_size: Based on GPUs
# - checkpoint.save and checkpoint.load: Checkpoint directory
```

Then run:

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
  examples/megatron/recipes/wan/pretrain_wan.py \
  --training-mode pretrain \
  --config-file examples/megatron/recipes/wan/conf/my_wan.yaml
```

You can also override any config values from the command line. For example:

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
  examples/megatron/recipes/wan/pretrain_wan.py \
  --config-file examples/megatron/recipes/wan/conf/my_wan.yaml \
  --training-mode pretrain \
  dataset.path=/opt/wan_webdataset \
  train.global_batch_size=8 \
  train.micro_batch_size=1 \
  model.tensor_model_parallel_size=2 \
  model.context_parallel_size=4 \
  checkpoint.save=/opt/pretrained_checkpoints \
  checkpoint.load=/opt/pretrained_checkpoints
```

#### 🧪 Quick Start with Mock Dataset
If you want to run without a real dataset (for debugging or performance measurement), pass `--mock`:

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
  examples/megatron/recipes/wan/pretrain_wan.py \
  --config-file examples/megatron/recipes/wan/conf/wan_1_3B.yaml \
  --training-mode pretrain \
  --mock
```

You may adjust mock shapes (`F_latents`, `H_latents`, `W_latents`) and packing behavior (`number_packed_samples`) in `WanMockDataModuleConfig` (see `dfm/src/megatron/recipes/wan/wan.py`) to simulate different data scenarios.
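
To locate those fields quickly, a simple grep works (the field names are the ones listed above):

```bash
# Find the mock-data fields to tweak (latent shapes and packing count).
grep -nE "F_latents|H_latents|W_latents|number_packed_samples" \
  dfm/src/megatron/recipes/wan/wan.py
```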

---

### 🎬 Inference

After training, you can run inference with `examples/megatron/recipes/wan/inference_wan.py`. Set `--checkpoint_step` to use a specific checkpoint for inference. Set `--sizes` and `--frame_nums` to specify the video shape (frames, height, width). Set `--sample_steps` (default 50) for the number of diffusion sampling steps.

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node 1 \
  examples/megatron/recipes/wan/inference_wan.py \
  --task t2v-1.3B \
  --frame_nums 81 \
  --sizes 480*832 \
  --checkpoint_dir /opt/pretrained_checkpoints \
  --checkpoint_step 10000 \
  --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
  --sample_steps 50
```

**Note**: The current inference path is single-GPU; parallel inference is not yet supported.

---

### ⚡ Parallelism Support

The table below shows current parallelism support for different model sizes:

| Model | Data Parallel | Tensor Parallel | Sequence Parallel | Context Parallel | FSDP |
|---|---|---|---|---|---|
| 1.3B | ✅ | ✅ | ✅ | ✅ | Coming Soon |
| 14B | ✅ | ✅ | ✅ | ✅ | Coming Soon |
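
When combining the parallelism dimensions above, the number of GPUs must cover the product of the model-parallel sizes, and whatever is left over becomes data parallelism; the sketch below illustrates this with shell arithmetic under the usual Megatron convention (sequence parallelism reuses the tensor-parallel group, so it does not consume extra GPUs):

```bash
# Hypothetical sizing check for a single 8-GPU node.
num_gpus=8
tp=2   # model.tensor_model_parallel_size
cp=4   # model.context_parallel_size
echo "data-parallel size = $((num_gpus / (tp * cp)))"   # -> 1
```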

### References
Wan Team. (2025). Wan: Open and advanced large-scale video generative models (Wan 2.1). GitHub. https://github.com/Wan-Video/Wan2.1/