
Commit 852a5cb

huvunvidia and Huy Vu2 authored
Add docs for Megatron Wan (#38)
* add docs
* updated README for Wan
* update README wan
* relocate readme

---------

Co-authored-by: Huy Vu2 <[email protected]>
1 parent 407d986 commit 852a5cb

File tree

4 files changed: +244 −1 lines

dfm/src/megatron/model/wan/wan_provider.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -50,7 +50,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
     parallel_output: bool = True
     bf16: bool = False
     params_dtype: torch.dtype = torch.float32
-    qkv_format: str = "sbhd"  # "thd". NOTE: if we use context parallelism, we need to use "thd"
+    qkv_format: str = "thd"  # "sbhd". NOTE: if we use context parallelism, we need to use "thd"
     apply_rope_fusion: bool = True
     bias_activation_fusion: bool = True
     # these attributes are unused for images/videos, we just set because bridge training requires for LLMs
```
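For orientation, here is a minimal PyTorch sketch (not code from the repository) of the two layouts this flag selects, following the Transformer Engine naming: `sbhd` keeps a padded `[sequence, batch, heads, head_dim]` tensor, while `thd` packs variable-length samples into a single `[total_tokens, heads, head_dim]` tensor described by `cu_seqlens`, which is what context parallelism and sequence packing expect. The shapes below are made up for illustration.

```python
import torch

# Two variable-length "videos" flattened to token sequences (lengths are made up).
seq_lens = [1200, 800]
heads, dim = 12, 128

# sbhd: pad every sample to the longest sequence -> wasted compute on pad tokens.
sbhd = torch.zeros(max(seq_lens), len(seq_lens), heads, dim)   # [s, b, h, d]

# thd: concatenate samples along the token axis, no padding.
thd = torch.zeros(sum(seq_lens), heads, dim)                   # [t, h, d]

# cu_seqlens marks where each packed sample starts/ends: tensor([0, 1200, 2000]).
cu_seqlens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)

print(sbhd.numel() / thd.numel())  # the padded layout allocates 1.2x more here
```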
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
## 🚀 Megatron WAN

### 📋 Overview
An open-source implementation of [WAN 2.1](https://github.com/Wan-Video/Wan2.1) (large-scale text-to-video/image generative models) built on top of [Megatron-Core](https://github.com/NVIDIA/Megatron-LM) and [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) for scalable and efficient training. It supports advanced parallelism strategies (data, tensor, sequence, and context parallelism) and optimized kernels (e.g., Transformer Engine fused attention).

---

### 📦 Dataset Preparation
This recipe uses NVIDIA's [Megatron-Energon](https://github.com/NVIDIA/Megatron-Energon) as an efficient multi-modal data loader. Datasets should be in the WebDataset-compatible format (typically sharded `.tar` archives). Energon supports large-scale distributed loading, sharding, and sampling for video-text and image-text pairs. Set `dataset.path` to your WebDataset directory or shard pattern. See the Megatron-Energon docs for format details, subflavors, and advanced options.

If you do not have a dataset yet or only need to validate performance/plumbing, see the "Quick Start with Mock Dataset" section below.
---

#### 🗂️ Dataset Preparation Example
Starting with a directory containing raw `.mp4` videos and their corresponding `.json` metadata files with captions, you can turn the data into WAN-ready WebDataset shards using our helper script. We then use Energon to process those shards and create its metadata. After this, you can set the training script's `dataset.path` argument to the processed output folder and start training.
```bash
# 1) Define your input (raw videos) and output (WebDataset shards) folders. For example:
DATASET_SRC=/opt/raw_videos       # contains .mp4 and per-video .jsonl captions
DATASET_PATH=/opt/wan_webdataset  # output WebDataset shards

# 2) (Optional) If your WAN models require auth on first download
export HF_TOKEN=<your_huggingface_token>

# 3) Create WAN shards with latents + text embeddings
# Wan's VAE encoder and T5 encoder are used to extract video latents and caption
# embeddings offline before training, using the following core arguments:
#   --height/--width: control the resize target (832x480 is supported for both the 1.3B and 14B models)
#   --center-crop: center-crop to the exact target size after resizing
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node 1 \
    examples/megatron/recipes/wan/prepare_energon_dataset_wan.py \
    --video_folder "${DATASET_SRC}" \
    --output_dir "${DATASET_PATH}" \
    --model "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" \
    --height 480 --width 832 \
    --center-crop

# 4) Use Energon to process the shards and create its metadata/spec
energon prepare "${DATASET_PATH}"
# In the interactive prompts:
# - Enter a train/val/test split, e.g., "8,1,1"
# - When asked for the sample type, choose: "Crude sample (plain dict for cooking)"
```
What gets produced:
- Each shard contains:
  - `pth`: WAN video latents
  - `pickle`: text embeddings
  - `json`: useful side info (text caption, sizes, processing choices, etc.)
- Energon writes a `.nv-meta` directory with dataset info and a `dataset.yaml` you can version-control.
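As a quick sanity check, you can peek inside a generated shard with the standard library and PyTorch. This is an illustrative sketch only: the shard filename and the per-entry extensions below are assumptions, so adjust them to what `prepare_energon_dataset_wan.py` actually wrote to your output directory.

```python
import io
import json
import pickle
import tarfile

import torch

# Hypothetical shard name; list ${DATASET_PATH} to find the real ones.
shard = "/opt/wan_webdataset/shard_000000.tar"

with tarfile.open(shard) as tf:
    for member in tf.getmembers()[:6]:  # inspect the first few entries
        data = tf.extractfile(member).read()
        if member.name.endswith(".pth"):
            latents = torch.load(io.BytesIO(data), map_location="cpu")
            print(member.name, "latents:", getattr(latents, "shape", type(latents)))
        elif member.name.endswith(".pickle"):
            emb = pickle.loads(data)
            print(member.name, "text embedding:", getattr(emb, "shape", type(emb)))
        elif member.name.endswith(".json"):
            meta = json.loads(data)
            print(member.name, "keys:", sorted(meta.keys()))
```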
You're ready to launch training. In the training config (or via CLI overrides), point the WAN recipe to the processed output directory with `dataset.path=${DATASET_PATH}`.
---

### 🐳 Build Container

Please follow the instructions in the container section of the main README:

- DFM container guide: https://github.com/NVIDIA-NeMo/DFM#-built-your-own-container

---

### 🏋️ Pretraining
This recipe leverages sequence packing to maximize throughput. When a batch contains videos with different shapes or resolutions, naive batching with padding wastes a significant number of tokens on padding, because video lengths and resolutions vary so much. Sequence packing instead stacks multiple samples (with different resolutions) into a single sequence, so no computation is wasted on padded tokens (a small sketch of the padding cost follows the list below). When using sequence packing:
- Set `train.micro_batch_size=1` and `dataset.micro_batch_size=1`
- Ensure `model.qkv_format=thd` (required with context parallelism and recommended with sequence packing)
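To make the padding cost concrete, here is a small back-of-the-envelope sketch; the token counts are made up, not taken from the recipe:

```python
# Illustrative token counts for three video samples of different lengths/resolutions.
sample_tokens = [30720, 12288, 20480]

# Naive batching: every sample in the batch is padded up to the longest one.
padded = len(sample_tokens) * max(sample_tokens)

# Sequence packing: samples are concatenated into one sequence, no padding.
packed = sum(sample_tokens)

print(f"padded: {padded} tokens, packed: {packed} tokens")
print(f"compute wasted on padding: {1 - packed / padded:.0%}")  # ~31% here
```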
Multiple parallelism techniques, including tensor, sequence, and context parallelism, are supported and configurable for your hardware.

Wan training is driven by `examples/megatron/recipes/wan/pretrain_wan.py`, which supports both a YAML config file and CLI overrides.

The script exposes a `--training-mode` flag with `pretrain` and `finetune` presets for flow-matching hyperparameters as a starting point for experiments. These presets specify that pretraining uses noisier, biased sampling (e.g., logit-normal sampling, higher `logit_std`, lower `flow_shift`) for stability and broad learning, while finetuning uses uniform, lower-noise settings (e.g., uniform sampling, lower `logit_std`, higher `flow_shift`) to refine details and improve quality.
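As a rough illustration of what the two presets change (the constants and the shift mapping below are assumptions for this sketch, not values read from the recipe), logit-normal versus uniform timestep sampling with a timestep shift might look like this:

```python
import torch

def sample_timesteps(n: int, mode: str = "pretrain") -> torch.Tensor:
    """Illustrative flow-matching timestep sampling; all constants are assumptions."""
    if mode == "pretrain":
        logit_mean, logit_std, flow_shift = 0.0, 2.0, 1.0           # noisier, biased sampling
        t = torch.sigmoid(torch.randn(n) * logit_std + logit_mean)  # logit-normal in (0, 1)
    else:  # finetune
        flow_shift = 3.0                                             # lower noise, higher shift
        t = torch.rand(n)                                            # uniform in (0, 1)
    # One common timestep-shift mapping; the recipe may use a different one.
    return flow_shift * t / (1 + (flow_shift - 1) * t)

print(sample_timesteps(4, "pretrain"))
print(sample_timesteps(4, "finetune"))
```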
**Note**: If you use `logger.wandb_project` and `logger.wandb_exp_name`, export `WANDB_API_KEY`.

#### Pretraining script example

We provide example configs for the 1.3B and 14B model sizes on a mock dataset (see `wan_1_3B.yaml` and `wan_14B.yaml` under `examples/megatron/recipes/wan/conf`). Starting from these, you can create your own configuration by copying one of the example override configs and updating it with your settings (e.g., the actual processed data path and the parallelism sizes that fit your hardware). For more detail on the arguments, see the [Megatron-Bridge docs](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docs/megatron-lm-to-megatron-bridge.md).
```bash
cp examples/megatron/recipes/wan/conf/wan_1_3B.yaml examples/megatron/recipes/wan/conf/my_wan.yaml
# Edit my_wan.yaml to set:
# - dataset.path: Path to your WebDataset directory
# - train.global_batch_size / train.micro_batch_size: Keep micro_batch_size=1
# - model.tensor_model_parallel_size / model.context_parallel_size: Based on GPUs
# - checkpoint.save and checkpoint.load: Checkpoint directory
```
Then run:

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
    examples/megatron/recipes/wan/pretrain_wan.py \
    --training-mode pretrain \
    --config-file examples/megatron/recipes/wan/conf/my_wan.yaml
```
You can also override any config values from the command line. For example:

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
    examples/megatron/recipes/wan/pretrain_wan.py \
    --config-file examples/megatron/recipes/wan/conf/my_wan.yaml \
    --training-mode pretrain \
    dataset.path=/opt/wan_webdataset \
    train.global_batch_size=8 \
    train.micro_batch_size=1 \
    model.tensor_model_parallel_size=2 \
    model.context_parallel_size=4 \
    checkpoint.save=/opt/pretrained_checkpoints \
    checkpoint.load=/opt/pretrained_checkpoints
```
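A quick sanity check on the parallel sizes in that command, assuming standard Megatron semantics (and assuming `$num_gpus=8` for this illustration): the product of tensor, pipeline, and context parallel sizes must divide the world size, and whatever is left over becomes data parallelism.

```python
# Illustrative check for the override example above; adjust to your launch.
world_size = 8                   # assumed value of $num_gpus
tp, pp, cp = 2, 1, 4             # tensor / pipeline / context parallel sizes

model_parallel = tp * pp * cp    # GPUs consumed by one model replica
assert world_size % model_parallel == 0, "world size must be divisible by tp * pp * cp"

dp = world_size // model_parallel
micro_batch_size, global_batch_size = 1, 8
grad_accum = global_batch_size // (micro_batch_size * dp)

print(f"data-parallel replicas: {dp}, gradient accumulation steps: {grad_accum}")
# -> data-parallel replicas: 1, gradient accumulation steps: 8
```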
#### 🧪 Quick Start with Mock Dataset
If you want to run without a real dataset (for debugging or performance measurement), pass `--mock`:

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
    examples/megatron/recipes/wan/pretrain_wan.py \
    --config-file examples/megatron/recipes/wan/conf/wan_1_3B.yaml \
    --training-mode pretrain \
    --mock
```
You may adjust mock shapes (`F_latents`, `H_latents`, `W_latents`) and packing behavior (`number_packed_samples`) in `WanMockDataModuleConfig` (see `dfm/src/megatron/recipes/wan/wan.py`) to simulate different data scenarios.
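As a rough guide to how those knobs scale the packed sequence length (the values below are made up, and the model may additionally patchify the latents before attention):

```python
# Made-up mock shapes; tweak them to stress different sequence lengths.
F_latents, H_latents, W_latents = 21, 60, 104
number_packed_samples = 4

latent_cells_per_sample = F_latents * H_latents * W_latents     # 131040
packed_length = number_packed_samples * latent_cells_per_sample
print(latent_cells_per_sample, packed_length)                    # 131040 524160
```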
---

### 🎬 Inference

After training, you can run inference with `examples/megatron/recipes/wan/inference_wan.py`. Set `--checkpoint_step` to select a specific checkpoint, `--sizes` and `--frame_nums` to specify the video shape (frames, height, width), and `--sample_steps` (default: 50) for the number of diffusion sampling steps.
```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node 1 \
    examples/megatron/recipes/wan/inference_wan.py \
    --task t2v-1.3B \
    --frame_nums 81 \
    --sizes 480*832 \
    --checkpoint_dir /opt/pretrained_checkpoints \
    --checkpoint_step 10000 \
    --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --sample_steps 50
```
**Note**: Current inference path is single-GPU. Parallel inference is not yet supported.

---

### ⚡ Parallelism Support

The table below shows current parallelism support for different model sizes:
| Model | Data Parallel | Tensor Parallel | Sequence Parallel | Context Parallel | FSDP |
|---|---|---|---|---|---|
| 1.3B | ✅ | ✅ | ✅ | ✅ | Coming Soon |
| 14B | ✅ | ✅ | ✅ | ✅ | Coming Soon |
### References

Wan Team. (2025). Wan: Open and advanced large-scale video generative models (Wan 2.1). GitHub. https://github.com/Wan-Video/Wan2.1/
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
```yaml
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example override file

# To override a parameter, ensure the structure matches the ConfigContainer
# and its sub-configurations (e.g., model, train, etc.)
# Top-level ConfigContainer fields are dataclasses themselves

model:
  crossattn_emb_size: 5120
  hidden_size: 5120
  ffn_hidden_size: 13824
  num_attention_heads: 40
  num_layers: 40
  tensor_model_parallel_size: 2
  pipeline_model_parallel_size: 1
  context_parallel_size: 4
  sequence_parallel: true
  recompute_granularity: full
  recompute_method: uniform
  recompute_num_layers: 1

train:
  global_batch_size: 1
  micro_batch_size: 1

dataset:
  global_batch_size: 1
  micro_batch_size: 1
```
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
```yaml
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example override file

# To override a parameter, ensure the structure matches the ConfigContainer
# and its sub-configurations (e.g., model, train, etc.)
# Top-level ConfigContainer fields are dataclasses themselves

model:
  crossattn_emb_size: 1536
  hidden_size: 1536
  ffn_hidden_size: 8960
  num_attention_heads: 12
  num_layers: 30
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  context_parallel_size: 8
  sequence_parallel: false

train:
  global_batch_size: 2
  micro_batch_size: 1

dataset:
  global_batch_size: 2
  micro_batch_size: 1
```
