Commit 80f41ed

Merge branch 'main' into llane/site-config-and-skeleton

2 parents 7072c2a + 1061749

40 files changed (+3982, -268 lines)

3rdparty/Megatron-Bridge

Submodule Megatron-Bridge updated 298 files

README.md

Lines changed: 174 additions & 29 deletions
@@ -1,30 +1,175 @@
**Removed (old README):**

# NeMo DFM: Diffusion Foundation Models collection

NeMo DFM is a state-of-the-art framework for fast, large-scale training and inference of video world models. It unifies the latest diffusion-based and autoregressive techniques, prioritizing efficiency and performance from research prototyping to production deployment.

## Projects

This collection consists of 4 projects:

1. [Scalable diffusion training framework](nemo_vfm/diffusion/readme.rst)
2. [Accelerated diffusion world models](nemo_vfm/physicalai/Cosmos/cosmos1/models/diffusion/README.md)
3. [Accelerated autoregressive world models](nemo_vfm/physicalai/Cosmos/cosmos1/models/autoregressive/README.md)
4. [Sparse attention for efficient diffusion inference](nemo_vfm/sparse_attention/README.md)

## Citations

If you find our code useful, please consider citing the following papers:

```bibtex
@article{patel2025training,
  title={Training Video Foundation Models with NVIDIA NeMo},
  author={Patel, Zeeshan and He, Ethan and Mannan, Parth and Ren, Xiaowei and Wolf, Ryan and Agarwal, Niket and Huffman, Jacob and Wang, Zhuoyao and Wang, Carl and Chang, Jack and others},
  journal={arXiv preprint arXiv:2503.12964},
  year={2025}
}

@article{agarwal2025cosmos,
  title={Cosmos World Foundation Model Platform for Physical AI},
  author={Agarwal, Niket and Ali, Arslan and Bala, Maciej and Balaji, Yogesh and Barker, Erik and Cai, Tiffany and Chattopadhyay, Prithvijit and Chen, Yongxin and Cui, Yin and Ding, Yifan and others},
  journal={arXiv preprint arXiv:2501.03575},
  year={2025}
}
```
**Added (new README):**

<div align="center">

# NeMo DFM: Diffusion Foundation Models

<!-- We are still using Mbridge CICD NeMo. @pablo can we get our own? and the same for the stargazer badge -->

<!-- Not including codecov for now since we have not worked on it extensively -->

[![CICD NeMo](https://github.com/NVIDIA-NeMo/DFM/actions/workflows/cicd-main.yml/badge.svg)](https://github.com/NVIDIA-NeMo/DFM/actions/workflows/cicd-main.yml)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/)
[![GitHub Stars](https://img.shields.io/github/stars/NVIDIA-NeMo/DFM.svg?style=social&label=Star&cacheSeconds=14400)](https://github.com/NVIDIA-NeMo/DFM/stargazers/)

[Documentation](https://github.com/NVIDIA-NeMo/DFM/tree/main/docs) | [Supported Models](#supported-models) | [Examples](https://github.com/NVIDIA-NeMo/DFM/tree/main/examples) | [Contributing](https://github.com/NVIDIA-NeMo/DFM/tree/main/CONTRIBUTING.md)

</div>

## Overview

NeMo DFM (Diffusion Foundation Models) is a library under the [NeMo Framework](https://github.com/NVIDIA-NeMo), focused on diffusion models for **Video**, **Image**, and **Text** generation. It unifies cutting-edge diffusion-based architectures and training techniques, prioritizing efficiency and performance from research prototyping to production deployment.

**Dual-Path Architecture**: DFM provides two complementary training paths to maximize flexibility:

- **🌉 Megatron Bridge Path**: Built on [NeMo Megatron Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge), which leverages [Megatron Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) for maximum scalability with n-D parallelism (TP, PP, CP, EP, VPP, DP)
- **🚀 AutoModel Path**: Built on [NeMo AutoModel](https://github.com/NVIDIA-NeMo/Automodel) for PyTorch DTensor-native SPMD training, enabling easy experimentation and Day-0 support for 🤗 Hugging Face models (see the sketch below)

Choose the path that best fits your workflow, or use both for different stages of development!
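To make the AutoModel bullet concrete, here is a minimal sketch of the DTensor-style SPMD sharding that path builds on. This is plain PyTorch (>= 2.4), not DFM API:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Launch with: torchrun --nproc-per-node=8 dtensor_sketch.py
mesh = init_device_mesh("cuda", (8,))                  # 1-D mesh over 8 GPUs
torch.manual_seed(0)                                   # same global tensor on every rank
weight = torch.randn(4096, 4096)
dweight = distribute_tensor(weight, mesh, [Shard(0)])  # shard rows across the mesh
print(dweight.to_local().shape)                        # torch.Size([512, 4096]) per rank
```

Under `torchrun`, every rank runs the same script and holds one shard of the tensor; the AutoModel path applies this DTensor-native style of parallelism to whole models.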
<!-- Once we have updated images of how DFM fits into the NeMo journey, put them here. @Eliiot can help. -->

## 🔧 Installation

### 🐳 Build your own Container

#### 1. Build the container

```bash
# Initialize all submodules (Megatron-Bridge, Automodel, and nested Megatron-LM)
git submodule update --init --recursive

# Build the container
docker build -f docker/Dockerfile.ci -t dfm:dev .
```

#### 2. Start the container

```bash
docker run --rm -it --gpus all \
  --entrypoint bash \
  -v $(pwd):/opt/DFM dfm:dev
```
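Once inside the container, a quick GPU sanity check (a hypothetical smoke test, assuming the image ships with PyTorch; not part of the repo's documented steps):

```bash
python -c "import torch; print(torch.cuda.device_count(), 'GPUs visible')"
```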
### 📦 Using DFM Docker (Coming Soon)
## ⚡ Quickstart

### Megatron Bridge Path

#### Run a Recipe

You can find all predefined recipes in the [recipes](https://github.com/NVIDIA-NeMo/DFM/tree/main/examples/megatron/recipes) directory.

> **Note:** Run the recipes with [uv](https://docs.astral.sh/uv/), passing `--group megatron-bridge`.

```bash
uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \
    examples/megatron/recipes/wan/pretrain_wan.py \
    --config-file examples/megatron/recipes/wan/conf/wan_1_3B.yaml \
    --training-mode pretrain \
    --mock
```
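In the command above, `$num_gpus` is the number of GPUs to launch per node (for example, `num_gpus=8` on a single 8-GPU machine), and `--mock` appears to select synthetic data, so this smoke test needs no real dataset.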
### AutoModel Path

Train with PyTorch-native DTensor parallelism and direct 🤗 HF integration:

#### Run a Recipe

You can find pre-configured recipes in the [automodel/finetune](https://github.com/NVIDIA-NeMo/DFM/tree/main/examples/automodel/finetune) and [automodel/pretrain](https://github.com/NVIDIA-NeMo/DFM/tree/main/examples/automodel/pretrain) directories.

> Note: AutoModel examples live under `dfm/examples/automodel`. Use [uv](https://docs.astral.sh/uv/) with `--group automodel`. Configs are YAML-driven; pass `-c <path>` to override the default.

The fine-tune recipe sets up WAN 2.1 Text-to-Video training with Flow Matching using FSDP2 hybrid sharding. It parallelizes the heavy transformer blocks while keeping lightweight modules (e.g., the VAE) unsharded for efficiency. Adjust batch sizes, learning rate, and parallel sizes in `dfm/examples/automodel/finetune/wan2_1_t2v_flow.yaml`. The generation script demonstrates distributed inference with AutoModel DTensor managers, producing an MP4 on rank 0; frame size, frame count, sampling steps, and CFG scale are adjustable via flags.

```bash
# Fine-tune WAN 2.1 T2V with FSDP2 (single node, 8 GPUs)
uv run --group automodel torchrun --nproc-per-node=8 \
    dfm/examples/automodel/finetune/finetune.py \
    -c dfm/examples/automodel/finetune/wan2_1_t2v_flow.yaml

# Generate videos with FSDP2 (distributed inference)
uv run --group automodel torchrun --nproc-per-node=8 \
    dfm/examples/automodel/generate/wan_generate.py
```
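As a rough sketch of the kinds of knobs that paragraph refers to, an override might look like the following. The key names here are hypothetical; treat the YAML file in the repo as the authoritative schema:

```yaml
# Hypothetical override sketch only; see
# dfm/examples/automodel/finetune/wan2_1_t2v_flow.yaml for the real keys.
micro_batch_size: 1      # per-GPU batch size
learning_rate: 1.0e-5    # optimizer LR for fine-tuning
dp_shard_size: 8         # FSDP2 shard degree within a replica group
dp_replicate_size: 1     # replica groups for hybrid sharding
```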
## 🚀 Key Features

### Dual Training Paths

**Megatron Bridge** delivers maximum throughput and scalability, with near-linear scaling to thousands of nodes. **AutoModel** provides an easy on-ramp for experimentation and research with PyTorch-native SPMD training.

### Shared Capabilities

- **🎥 Multi-Modal Diffusion**: Support for video, image, and text generation
- **🔬 Advanced Samplers**: EDM, Flow Matching, and custom diffusion schedules (a flow-matching sketch follows this list)
- **🎭 Flexible Architectures**: DiT (Diffusion Transformers), WAN (World Action Networks)
- **📊 Efficient Data Loading**: Data pipelines with sequence packing
- **💾 Distributed Checkpointing**: SafeTensors-based sharded checkpoints
- **🌟 Memory Optimization**: Gradient checkpointing, mixed precision, efficient attention
- **🤗 HuggingFace Integration**: Seamless integration with the HF ecosystem
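Since several recipes above train with flow matching, here is a minimal sketch of what a flow-matching (rectified flow) training step computes. This is generic PyTorch for illustration, not DFM's implementation; the `model(x_t, t, cond)` signature is a hypothetical stand-in:

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, cond):
    """x0: clean latents [B, ...]; model predicts velocity at (x_t, t)."""
    b = x0.shape[0]
    t = torch.rand(b, device=x0.device)          # t ~ U(0, 1)
    noise = torch.randn_like(x0)
    t_b = t.view(b, *([1] * (x0.dim() - 1)))     # broadcast t over latent dims
    x_t = (1.0 - t_b) * noise + t_b * x0         # straight-line noise-to-data path
    v_target = x0 - noise                        # constant velocity along that path
    v_pred = model(x_t, t, cond)
    return F.mse_loss(v_pred, v_target)
```

The network regresses the constant velocity `x0 - noise` along the straight path between noise and data; sampling then integrates this velocity field from t=0 to t=1.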
## Supported Models

DFM provides out-of-the-box support for state-of-the-art diffusion architectures:

| Model | Type | Megatron Bridge | AutoModel | Description |
|-------|------|-----------------|-----------|-------------|
| **DiT** | Image/Video | [pretrain](https://github.com/NVIDIA-NeMo/DFM/blob/main/examples/megatron/recipes/dit/pretrain_dit_model.py), [inference](https://github.com/NVIDIA-NeMo/DFM/blob/main/examples/megatron/recipes/dit/inference_dit_model.py) | 🔜 | Diffusion Transformers with scalable architecture |
| **WAN 2.1** | Video | [inference](https://github.com/NVIDIA-NeMo/DFM/blob/main/examples/megatron/recipes/wan/inference_wan.py), [pretrain, finetune](https://github.com/NVIDIA-NeMo/DFM/blob/main/examples/megatron/recipes/wan/pretrain_wan.py) | [pretrain](https://github.com/NVIDIA-NeMo/DFM/tree/main/examples/automodel/pretrain), [finetune](https://github.com/NVIDIA-NeMo/DFM/tree/main/examples/automodel/finetune), [inference](https://github.com/NVIDIA-NeMo/DFM/blob/main/examples/automodel/generate/wan_validate.py) | World Action Networks for video generation |
## Performance Benchmarking

For detailed performance benchmarks, including throughput metrics across different GPU systems and model configurations, see the [Performance Summary](https://github.com/NVIDIA-NeMo/DFM/blob/main/docs/performance-summary.md) in our documentation.
## Project Structure

```
DFM/
├── dfm/
│   └── src/
│       ├── megatron/              # Megatron Bridge path
│       │   ├── base/              # Base utilities for Megatron
│       │   ├── data/              # Data loaders and task encoders
│       │   │   ├── common/        # Shared data utilities
│       │   │   └── <model_name>/  # Model-specific data handling
│       │   ├── model/             # Model implementations
│       │   │   ├── common/        # Shared model components
│       │   │   └── <model_name>/  # Model-specific implementations
│       │   └── recipes/           # Training recipes
│       │       └── <model_name>/  # Model-specific training configs
│       ├── automodel/             # AutoModel path (DTensor-native)
│       │   ├── _diffusers/        # Diffusion pipeline integrations
│       │   ├── datasets/          # Dataset implementations
│       │   ├── distributed/       # Parallelization strategies
│       │   ├── flow_matching/     # Flow matching implementations
│       │   ├── recipes/           # Training scripts
│       │   └── utils/             # Utilities and validation
│       └── common/                # Shared across both paths
│           ├── data/              # Common data utilities
│           └── utils/             # Batch ops, video utils, etc.
├── examples/                      # Example scripts and configs
```
## 🤝 Contributing

We welcome contributions! Please see our [Contributing Guide](https://github.com/NVIDIA-NeMo/DFM/tree/main/CONTRIBUTING.md) for details on:

- Setting up your development environment
- Code style and testing guidelines
- Submitting pull requests
- Reporting issues

For questions or discussions, please open an issue on GitHub.
## Acknowledgements

NeMo DFM builds upon the excellent work of:

- [Megatron Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) - Advanced model parallelism
- [Megatron Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) - HuggingFace ↔ Megatron bridge
- [NeMo AutoModel](https://github.com/NVIDIA-NeMo/Automodel) - PyTorch-native SPMD training
- [PyTorch Distributed](https://pytorch.org/docs/stable/distributed.html) - Foundation for distributed training
- [Diffusers](https://github.com/huggingface/diffusers) - Diffusion model implementations

dfm/src/common/utils/save_video.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -44,5 +44,4 @@ def save_video(
         "output_params": ["-f", "mp4"],
     }
 
-    print("video_save_path", video_save_path)
    imageio.mimsave(video_save_path, grid, "mp4", **kwargs)
```

dfm/src/megatron/data/common/diffusion_energon_datamodule.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -55,7 +55,11 @@ def __post_init__(self):
         self.sequence_length = self.dataset.seq_length
 
     def build_datasets(self, context: DatasetBuildContext):
-        return self.dataset.train_dataloader(), self.dataset.val_dataloader(), self.dataset.test_dataloader()
+        return (
+            iter(self.dataset.train_dataloader()),
+            iter(self.dataset.val_dataloader()),
+            iter(self.dataset.val_dataloader()),
+        )
 
 
 class DiffusionDataModule(EnergonMultiModalDataModule):
```
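After this change, `build_datasets` hands back plain iterators, so a consumer can call `next()` directly. A hypothetical usage (names illustrative, not repo code):

```python
# Hypothetical consumer of the new return type.
train_iter, val_iter, test_iter = provider.build_datasets(context)
batch = next(train_iter)  # iterators, not DataLoader objects, after this commit
```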

dfm/src/megatron/data/common/diffusion_sample.py

Lines changed: 17 additions & 7 deletions

```diff
@@ -80,25 +80,35 @@ def to_dict(self) -> dict:
     def __add__(self, other: Any) -> int:
         """Adds the sequence length of this sample with another sample or integer."""
         if isinstance(other, DiffusionSample):
-            # Combine the values of the two instances
-            return self.seq_len_q.item() + other.seq_len_q.item()
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item()
+            return self_len + other_len
         elif isinstance(other, int):
-            # Add an integer to the value
-            return self.seq_len_q.item() + other
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            return self_len + other
         raise NotImplementedError
 
     def __radd__(self, other: Any) -> int:
         """Handles reverse addition for summing with integers."""
         # This is called if sum or other operations start with a non-DiffusionSample object.
         # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__.
         if isinstance(other, int):
-            return self.seq_len_q.item() + other
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            return self_len + other
         raise NotImplementedError
 
     def __lt__(self, other: Any) -> bool:
         """Compares this sample's sequence length with another sample or integer."""
         if isinstance(other, DiffusionSample):
-            return self.seq_len_q.item() < other.seq_len_q.item()
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item()
+            return self_len < other_len
         elif isinstance(other, int):
-            return self.seq_len_q.item() < other
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            return self_len < other
         raise NotImplementedError
```
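These dunder methods let the sequence packer sum and sort samples as if they were integers. With context parallelism (CP), each sample is padded so its length splits evenly across CP ranks, and that padded length is what actually occupies a pack, which is why the methods now prefer `seq_len_q_padded`. A toy illustration (the padding rule below is hypothetical, not DFM's exact one):

```python
# Toy example: budgeting a pack by padded length, not raw length.
def pad_for_cp(seq_len: int, cp_size: int = 2) -> int:
    multiple = 2 * cp_size                      # assumed CP divisibility rule
    return ((seq_len + multiple - 1) // multiple) * multiple

lengths = [1000, 900, 150]
padded = [pad_for_cp(n) for n in lengths]       # [1000, 900, 152]
print(sum(lengths), sum(padded))                # 2050 vs 2052 tokens actually consumed
```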

dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -56,7 +56,7 @@ def __init__(
         self,
         *args,
         max_frames: int = None,
-        text_embedding_padding_size: int = 512,
+        text_embedding_max_length: int = 512,
         seq_length: int = None,
         patch_spatial: int = 2,
         patch_temporal: int = 1,
@@ -65,7 +65,7 @@ def __init__(
     ):
         super().__init__(*args, **kwargs)
         self.max_frames = max_frames
-        self.text_embedding_padding_size = text_embedding_padding_size
+        self.text_embedding_max_length = text_embedding_max_length
         self.seq_length = seq_length
         self.patch_spatial = patch_spatial
         self.patch_temporal = patch_temporal
```
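This is a pure rename, so call sites passing the old keyword will fail with a `TypeError`. A hypothetical migration (other constructor arguments omitted):

```python
# Before: DiffusionTaskEncoderWithSequencePacking(text_embedding_padding_size=512, ...)
# After:
encoder = DiffusionTaskEncoderWithSequencePacking(text_embedding_max_length=512)
```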

dfm/src/megatron/data/common/sequence_packing_utils.py

Lines changed: 0 additions & 32 deletions

```diff
@@ -71,35 +71,3 @@ def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]:
     """
     sorted_seqlens = sorted(seqlens, reverse=True)
     return first_fit(sorted_seqlens, pack_size)
-
-
-def concat_pad(tensor_list, max_seq_length):
-    """
-    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
-    to reach max_seq_length.
-
-    Args:
-        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
-        max_seq_length (int): The desired size of the first dimension of the output tensor.
-
-    Returns:
-        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
-    """
-    import torch
-
-    # Get common properties from the first tensor
-    other_shape = tensor_list[0].shape[1:]
-    dtype = tensor_list[0].dtype
-    device = tensor_list[0].device
-
-    # Initialize the result tensor with zeros
-    result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)
-
-    current_index = 0
-    for tensor in tensor_list:
-        length = tensor.shape[0]
-        # Directly assign the tensor to the result tensor without checks
-        result[current_index : current_index + length] = tensor
-        current_index += length
-
-    return result
```
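`concat_pad` is gone, but the bin-packing helpers above it remain. A quick usage sketch of `first_fit_decreasing` (the function name and semantics come from the hunk above; the import path is illustrative):

```python
# first_fit_decreasing sorts lengths in descending order, then first-fit
# places each one into the first bin with room, opening new bins as needed.
from sequence_packing_utils import first_fit_decreasing  # illustrative import

seqlens = [900, 700, 512, 300, 256]
bins = first_fit_decreasing(seqlens, pack_size=1024)
print(bins)  # [[900], [700, 300], [512, 256]]; each bin sums to <= 1024
```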

dfm/src/megatron/data/dit/dit_mock_datamodule.py

Lines changed: 9 additions & 4 deletions

```diff
@@ -113,7 +113,7 @@ def mock_batch(
         seq_len_kv=seq_len_kv_packed,
         seq_len_kv_padded=seq_len_kv_padded_packed,
         latent_shape=torch.tensor([[C, T, H, W] for _ in range(number_packed_samples)], dtype=torch.int32),
-        pos_ids=pos_ids_packed,
+        pos_ids=pos_ids_packed.unsqueeze(0),
         video_metadata=[{"caption": f"Mock video sample {i}"} for i in range(number_packed_samples)],
     )
 
@@ -131,16 +131,19 @@ class DiTMockDataModuleConfig(DatasetProvider):
     dataloader_type: str = "external"
     task_encoder_seq_length: int = None
     F_latents: int = 1
-    H_latents: int = 64
-    W_latents: int = 96
+    H_latents: int = 256
+    W_latents: int = 512
     patch_spatial: int = 2
     patch_temporal: int = 1
-    number_packed_samples: int = 3
+    number_packed_samples: int = 1
     context_seq_len: int = 512
     context_embeddings_dim: int = 1024
 
     def __post_init__(self):
         mock_ds = _MockDataset(length=1024)
+        kwargs = {}
+        if self.num_workers > 0:
+            kwargs["prefetch_factor"] = 8
         self._train_dl = DataLoader(
             mock_ds,
             batch_size=self.micro_batch_size,
@@ -157,6 +160,8 @@ def __post_init__(self):
             ),
             shuffle=False,
             drop_last=False,
+            pin_memory=True,
+            **kwargs,
         )
         self._train_dl = iter(self._train_dl)
         self.sequence_length = self.seq_length
```
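The `kwargs` guard around `prefetch_factor` is needed because `torch.utils.data.DataLoader` only accepts `prefetch_factor` when worker processes exist:

```python
from torch.utils.data import DataLoader

try:
    DataLoader([], num_workers=0, prefetch_factor=8)
except ValueError as e:
    print("rejected:", e)  # prefetch_factor requires num_workers > 0

DataLoader([], num_workers=2, prefetch_factor=8)  # OK: 8 batches prefetched per worker
```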

dfm/src/megatron/data/dit/dit_taskencoder.py

Lines changed: 4 additions & 5 deletions

```diff
@@ -31,9 +31,9 @@ class DiTTaskEncoder(DiffusionTaskEncoderWithSequencePacking):
     Attributes:
         cookers (list): A list of Cooker objects used for processing.
         max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None.
-        text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512.
+        text_embedding_max_length (int): The maximum length for text embeddings. Defaults to 512.
     Methods:
-        __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs):
+        __init__(*args, max_frames=None, text_embedding_max_size=512, **kwargs):
             Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size.
         encode_sample(sample: dict) -> dict:
             Encodes a given sample dictionary containing video and text data.
@@ -71,7 +71,6 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
             // self.patch_spatial**2
             // self.patch_temporal
         )
-        is_image = T == 1
 
         if seq_len > self.seq_length:
             print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}")
@@ -100,8 +99,8 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
         t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16)
         t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]
 
-        if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
-            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
+        if t5_text_embeddings_seq_length > self.text_embedding_max_length:
+            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_max_length]
         t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)
 
         pos_ids = rearrange(
```
