|
1 | | -## Megatron DiT |
| 1 | +## 🚀 Megatron DiT |
2 | 2 |
|
3 | | -### Overview |
4 | | -An open source implementation of Diffusion Transformers (DiTs) that can be used to train text-to-image/video models. The implementation is based on [Megatron-Core](https://github.com/NVIDIA/Megatron-LM) and [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) to bring both scalability and efficiency. Various parallelization techniques such as tensor, sequence, and context parallelism are currently supported. |
| 3 | +### 📋 Overview |
| 4 | +An open-source implementation of [Diffusion Transformers (DiTs)](https://github.com/facebookresearch/DiT) that can be used to train text-to-image/video models using the [EDMPipeline](https://arxiv.org/abs/2206.00364). The implementation is based on [Megatron-Core](https://github.com/NVIDIA/Megatron-LM) and [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) to bring both scalability and efficiency. Various parallelization techniques such as tensor, sequence, and context parallelism are currently supported. |
5 | 5 |
|
| 6 | +--- |
6 | 7 |
|
7 | | -### Dataset Preparation |
8 | | -This recipe uses NVIDIA's [Megatron-Energon](https://github.com/NVIDIA/Megatron-Energon) as an efficient multi-modal data loader. Datasets should be in the WebDataset-compatible format (typically sharded `.tar` archives). Energon efficiently supports large-scale distributed loading, sharding, and sampling for multi-modal pairs (e.g., text-image, text-video). Set `dataset.path` to your WebDataset location or shard pattern (e.g., a directory containing shards). See the Megatron-Energon documentation for format details and advanced options. |
| 8 | +### 📦 Dataset Preparation |
| 9 | +This recipe uses NVIDIA's [Megatron-Energon](https://github.com/NVIDIA/Megatron-Energon) as an efficient multi-modal data loader. Datasets should be in the WebDataset-compatible format (typically sharded `.tar` archives). Energon efficiently supports large-scale distributed loading, sharding, and sampling for multi-modal pairs (e.g., text-image, text-video). Set `dataset.path` to your WebDataset location or shard pattern. See the Megatron-Energon documentation for format details and advanced options. |
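
For orientation, here is a hypothetical look inside one shard of a WebDataset-compatible dataset; the shard and member names below are placeholders, and the actual per-sample field extensions depend on how you build your dataset:

```bash
# WebDataset groups the files that make up one sample by a shared basename,
# with one extension per field (names below are purely illustrative).
tar -tf $dataset_path/shard_000000.tar | head
# 000001.jpg
# 000001.txt
# 000002.jpg
# 000002.txt
```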
9 | 10 |
|
10 | | -#### Dataset Preparation Example |
| 11 | +#### 🦋 Dataset Preparation Example |
11 | 12 |
|
12 | | -As an example you can use [butterfly-dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) available on Hugging Face. |
| 13 | +As an example, you can use the [butterfly-dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) available on Hugging Face. |
13 | 14 |
|
14 | | -The script below prepares the dataset to be compatible with Energon. t5_folder and tokenizer_cache_dir are optional parameters pointing to a T5 model and Video Tokenizer of your choice, otherwise the code downloads such artifacts. |
15 | | -``` bash |
16 | | -uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus\ |
17 | | - examples/megatron/recipes/dit/prepare_energon_dataset_butterfly.py\ |
18 | | - --t5_cache_dir $t5_folder\ |
| 15 | +The script below prepares the dataset to be compatible with Energon. |
| 16 | +```bash |
| 17 | +uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \ |
| 18 | + examples/megatron/recipes/dit/prepare_energon_dataset_butterfly.py |
| 19 | +``` |
| 20 | + |
| 21 | +If you already have the T5 model or the video tokenizer downloaded, you can point to them with the optional `--t5_cache_dir` and `--tokenizer_cache_dir` arguments:
| 22 | + |
| 23 | + |
| 24 | +```bash |
| 25 | +uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \ |
| 26 | + examples/megatron/recipes/dit/prepare_energon_dataset_butterfly.py \ |
| 27 | + --t5_cache_dir $t5_cache_dir \ |
19 | 28 | --tokenizer_cache_dir $tokenizer_cache_dir |
20 | 29 | ``` |
| 30 | + |
21 | 31 | Then you need to run `energon prepare $dataset_path` and choose `CrudeWebdataset` as the sample type: |
22 | 32 |
|
23 | 33 | ```bash |
@@ -69,87 +79,107 @@ Furthermore, you might want to add `subflavors` in your meta dataset specificati |
69 | 79 | Done |
70 | 80 | ``` |
71 | 81 |
|
72 | | -### Pretraining |
73 | | -These scripts assume you're using the Docker container provided by the repo. Use them to pre-train a DiT model on your own dataset. |
| 82 | +--- |
| 83 | +
|
| 84 | +### 🐳 Build Container |
| 85 | +
|
| 86 | +Please follow the instructions in the [container](https://github.com/NVIDIA-NeMo/DFM#-built-your-own-container) section of the main README. |
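
As a minimal sketch of the typical flow, assuming you build the image yourself: the image tag, Dockerfile, and mount paths below are placeholders, so follow the main README for the authoritative commands.

```bash
# Build the container image from the repository root (hypothetical tag).
docker build -t dfm-dit:dev .

# Start an interactive shell with GPU access; mount your data and checkpoints.
docker run --gpus all -it --rm \
  -v /path/to/data:/workspace/data \
  -v /path/to/checkpoints:/workspace/checkpoints \
  dfm-dit:dev
```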
| 87 | +
|
| 88 | +--- |
| 89 | +
|
| 90 | +### 🏋️ Pretraining |
74 | 91 |
|
75 | | -**Note:** Set the `WANDB_API_KEY` environment variable if you're using the `wandb_project` and `wandb_exp_name` arguments. |
| 92 | +Once you have the dataset and container ready, you can start training the DiT model. This repository leverages [sequence packing](https://docs.nvidia.com/nemo-framework/user-guide/24.09/nemotoolkit/features/optimizations/sequence_packing.html) to maximize training efficiency. Sequence packing stacks multiple samples into a single sequence instead of padding individual samples to a fixed length; therefore, `micro_batch_size` must be set to 1. Additionally, `qkv_format` should be set to `thd` to signal to Transformer Engine that sequence packing is enabled.
| 93 | +
|
| 94 | +For data loading, Energon provides two key hyperparameters related to sequence packing: `task_encoder_seq_length` and `packing_buffer_size`. The `task_encoder_seq_length` parameter controls the maximum sequence length passed to the model, while `packing_buffer_size` determines the number of samples processed to create different buckets. You can look at the `select_samples_to_pack` and `pack_selected_samples` methods of [DiffusionTaskEncoderWithSequencePacking](https://github.com/NVIDIA-NeMo/DFM/blob/main/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py#L50) to get a better sense of these parameters.
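
As a concrete sketch, the overrides below enable the settings just described. They assume the config-file workflow introduced in the example further down (with a copied `my_config.yaml`), and the packing length and buffer size are illustrative values rather than recommendations:

```bash
# micro_batch_size must stay at 1 and qkv_format must be thd when sequence
# packing is enabled; the last two values are tunable.
uv run --group megatron-bridge python -m torch.distributed.run \
    --nproc-per-node $NUM_GPUS examples/megatron/recipes/dit/pretrain_dit_model.py \
    --config-file examples/megatron/recipes/dit/conf/my_config.yaml \
    model.qkv_format=thd \
    train.micro_batch_size=1 \
    dataset.task_encoder_seq_length=15360 \
    dataset.packing_buffer_size=100
```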
| 95 | +
|
| 96 | +Multiple parallelism techniques including tensor, sequence, and context parallelism are supported and can be configured based on your computational requirements. |
| 97 | +
|
| 98 | +The model architecture can be customized through parameters such as `num_layers` and `num_attention_heads`. A comprehensive list of configuration options is available in the [Megatron-Bridge documentation](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docs/megatron-lm-to-megatron-bridge.md). |
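
For example, parallelism and architecture settings can be passed the same way; the sizes below are illustrative only and must be compatible with your GPU count (the product of the parallel sizes has to divide the total number of GPUs):

```bash
# Illustrative parallelism/architecture overrides on top of the copied
# my_config.yaml described in the example below.
uv run --group megatron-bridge python -m torch.distributed.run \
    --nproc-per-node $NUM_GPUS examples/megatron/recipes/dit/pretrain_dit_model.py \
    --config-file examples/megatron/recipes/dit/conf/my_config.yaml \
    model.tensor_model_parallel_size=2 \
    model.context_parallel_size=2 \
    model.num_layers=28 \
    model.num_attention_heads=16
```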
| 99 | +
|
| 100 | +
|
| 101 | +**Note:** If using the `wandb_project` and `wandb_exp_name` arguments, ensure the `WANDB_API_KEY` environment variable is set. |
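
For example (the key below is a placeholder):

```bash
# Authenticate Weights & Biases logging before launching training.
export WANDB_API_KEY=<your_wandb_api_key>
```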
| 102 | +
|
| 103 | +
|
| 104 | +**Note:** During validation, the model generates one sample per GPU at the start of each validation round. These samples are saved to a `validation_generation` folder within `checkpoint_dir` and are also logged to Wandb if the `WANDB_API_KEY` environment variable is configured. To decode the generated latent samples, the model requires access to the video tokenizer used during dataset preparation. Specify the location of the VAE artifacts with the `vae_cache_folder` argument; otherwise, they will be downloaded during the first validation round.
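
For instance, assuming the config-file workflow described in the next section, the overrides below point the model at a local VAE cache and control how often validation and sample generation run; the interval and iteration counts are illustrative:

```bash
# $CACHE_FOLDER should contain the tokenizer/VAE artifacts used during
# dataset preparation; the eval settings are example values only.
uv run --group megatron-bridge python -m torch.distributed.run \
    --nproc-per-node $NUM_GPUS examples/megatron/recipes/dit/pretrain_dit_model.py \
    --config-file examples/megatron/recipes/dit/conf/my_config.yaml \
    model.vae_cache_folder=$CACHE_FOLDER \
    train.eval_interval=1000 \
    train.eval_iters=32
```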
| 105 | +
|
| 106 | +#### 📝 Pretraining Script Example
| 107 | +First, copy the example config file and update it with your own settings: |
76 | 108 |
|
77 | 109 | ```bash |
78 | | -uv run --group megatron-bridge python -m torch.distributed.run \ |
79 | | - --nproc-per-node $NUM_GPUS examples/megatron/recipes/dit/pretrain_dit_model.py\ |
80 | | - model.tensor_model_parallel_size=1 \ |
81 | | - model.pipeline_model_parallel_size=1 \ |
82 | | - model.context_parallel_size=1 \ |
83 | | - model.qkv_format=thd \ |
84 | | - model.num_attention_heads=16\ |
85 | | - model.vae_cache_folder=$CACHE_FOLDER\ |
86 | | - dataset.path=$DATA_FOLDER \ |
87 | | - dataset.task_encoder_seq_length=15360\ |
88 | | - dataset.packing_buffer_size=100\ |
89 | | - dataset.num_workers=20\ |
90 | | - checkpoint.save=$CHECKPOINT_FOLDER \ |
91 | | - checkpoint.load=$CHECKPOINT_FOLDER \ |
92 | | - checkpoint.load_optim=true \ |
93 | | - checkpoint.save_interval=1000 \ |
94 | | - train.eval_interval=1000\ |
95 | | - train.train_iters=10000\ |
96 | | - train.eval_iters=32 \ |
97 | | - train.global_batch_size=$NUM_GPUS\ |
98 | | - train.micro_batch_size=1\ |
99 | | - logger.log_interval=10\ |
100 | | - logger.wandb_project="DiT"\ |
101 | | - logger.wandb_exp_name=$WANDB_NAME |
| 110 | +cp examples/megatron/recipes/dit/conf/dit_pretrain_example.yaml examples/megatron/recipes/dit/conf/my_config.yaml |
| 111 | +# Edit my_config.yaml to set: |
| 112 | +# - model.vae_cache_folder: Path to VAE cache folder |
| 113 | +# - dataset.path: Path to your dataset folder |
| 114 | +# - checkpoint.save and checkpoint.load: Path to checkpoint folder |
| 115 | +# - train.global_batch_size: Set to a value divisible by NUM_GPUS
| 116 | +# - logger.wandb_exp_name: Your experiment name |
102 | 117 | ``` |
103 | 118 |
|
104 | | -### Inference |
105 | | -``` bash |
106 | | -uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus examples/megatron/recipes/dit/inference_dit_model.py \ |
107 | | - --t5_cache_dir $artifact_dir \ |
108 | | - --tokenizer_cache_dir $tokenizer_cache_dir \ |
109 | | - --tokenizer_model Cosmos-0.1-Tokenizer-CV4x8x8\ |
110 | | - --checkpoint_path $checkpoint_dir \ |
111 | | - --num_video_frames 10 \ |
112 | | - --height 240 \ |
113 | | - --width 416 \ |
114 | | - --video_save_path $save_path \ |
115 | | - --prompt $prompt |
| 119 | +Then run: |
| 120 | +
|
| 121 | +```bash |
| 122 | +uv run --group megatron-bridge python -m torch.distributed.run \ |
| 123 | + --nproc-per-node $NUM_GPUS examples/megatron/recipes/dit/pretrain_dit_model.py \ |
| 124 | + --config-file examples/megatron/recipes/dit/conf/my_config.yaml |
116 | 125 | ``` |
117 | 126 |
|
118 | | -### Parallelism Support |
119 | | -The table below shows current parallelism support. |
| 127 | +You can still override any config values from the command line: |
120 | 128 |
|
121 | | - | Model | Data Parallel | Tensor Parallel | Sequence Parallel | Pipeline Parallel | Context Parallel | FSDP | |
122 | | - |---|---|---|---|---|---|---| |
123 | | - | **DiT-XL (700M)** | ✅ | ✅ | ✅ | | ✅ | | |
124 | | - | **DiT 7B** | | | | | | | |
| 129 | +```bash |
| 130 | +uv run --group megatron-bridge python -m torch.distributed.run \ |
| 131 | + --nproc-per-node $num_gpus examples/megatron/recipes/dit/pretrain_dit_model.py \ |
| 132 | + --config-file examples/megatron/recipes/dit/conf/my_config.yaml \ |
| 133 | + train.train_iters=20000 \ |
| 134 | + model.num_layers=32 |
| 135 | +``` |
125 | 136 |
|
| 137 | +**Note:** If you dedicate 100% of the data to training, you need to pass `dataset.use_train_split_for_val=true` so that a subset of the training data is used for validation:
126 | 138 |
|
127 | | -### Mock Dataset |
| 139 | +```bash |
| 140 | +uv run --group megatron-bridge python -m torch.distributed.run \ |
| 141 | + --nproc-per-node $num_gpus examples/megatron/recipes/dit/pretrain_dit_model.py \ |
| 142 | + --config-file examples/megatron/recipes/dit/conf/my_config.yaml \ |
| 143 | + dataset.use_train_split_for_val=true |
| 144 | +``` |
128 | 145 |
|
129 | | -For performance measurement purposes you can use the mock dataset by passing the `--mock` argument. |
| 146 | +#### 🧪 Quick Start with Mock Dataset |
130 | 147 |
|
131 | | -``` bash |
| 148 | +If you want to run the code without preparing a dataset (e.g., for performance measurement), you can pass the `--mock` flag to use a mock dataset.
| 149 | +
|
| 150 | +```bash |
132 | 151 | uv run --group megatron-bridge python -m torch.distributed.run \ |
133 | | - --nproc-per-node $NUM_GPUS examples/megatron/recipes/dit/pretrain_dit_model.py\ |
134 | | - model.tensor_model_parallel_size=1 \ |
135 | | - model.pipeline_model_parallel_size=1 \ |
136 | | - model.context_parallel_size=1 \ |
137 | | - model.qkv_format=thd \ |
138 | | - model.num_attention_heads=16\ |
139 | | - model.vae_cache_folder=$CACHE_FOLDER\ |
140 | | - dataset.path=$DATA_FOLDER \ |
141 | | - dataset.task_encoder_seq_length=15360\ |
142 | | - dataset.packing_buffer_size=100\ |
143 | | - dataset.num_workers=20\ |
144 | | - checkpoint.save=$CHECKPOINT_FOLDER \ |
145 | | - checkpoint.load=$CHECKPOINT_FOLDER \ |
146 | | - checkpoint.load_optim=true \ |
147 | | - checkpoint.save_interval=1000 \ |
148 | | - train.eval_interval=1000\ |
149 | | - train.train_iters=10000\ |
150 | | - train.eval_iters=32 \ |
151 | | - train.global_batch_size=$NUM_GPUS\ |
152 | | - train.micro_batch_size=1\ |
153 | | - logger.log_interval=10\ |
154 | | - --mock |
| 152 | + --nproc-per-node $num_gpus examples/megatron/recipes/dit/pretrain_dit_model.py \ |
| 153 | + --config-file examples/megatron/recipes/dit/conf/dit_pretrain.yaml \ |
| 154 | + --mock |
155 | 155 | ``` |
| 156 | +
|
| 157 | +### 🎬 Inference |
| 158 | +
|
| 159 | +Once training completes, you can run inference using [inference_dit_model.py](https://github.com/NVIDIA-NeMo/DFM/blob/main/examples/megatron/recipes/dit/inference_dit_model.py). The script requires your trained model checkpoint (`--checkpoint_path`) and a path to save generated videos (`--video_save_path`). You can pass two optional arguments, `--t5_cache_dir` and `--tokenizer_cache_dir`, to avoid re-downloading the T5 and tokenizer artifacts if you already have them locally.
| 160 | +
|
| 161 | +```bash |
| 162 | +uv run --group megatron-bridge python -m torch.distributed.run --nproc-per-node $num_gpus \ |
| 163 | + examples/megatron/recipes/dit/inference_dit_model.py \ |
| 164 | + --t5_cache_dir $artifact_dir \ |
| 165 | + --tokenizer_cache_dir $tokenizer_cache_dir \ |
| 166 | + --tokenizer_model Cosmos-0.1-Tokenizer-CV4x8x8 \ |
| 167 | + --checkpoint_path $checkpoint_dir \ |
| 168 | + --num_video_frames 10 \ |
| 169 | + --height 240 \ |
| 170 | + --width 416 \ |
| 171 | + --video_save_path $save_path \ |
| 172 | + --prompt "$prompt" |
| 173 | +``` |
| 174 | +
|
| 175 | +--- |
| 176 | +
|
| 177 | +### ⚡ Parallelism Support |
| 178 | +
|
| 179 | +The table below shows current parallelism support for different model sizes: |
| 180 | +
|
| 181 | +| Model | Data Parallel | Tensor Parallel | Sequence Parallel | Context Parallel | |
| 182 | +|---|---|---|---|---| |
| 183 | +| **DiT-S (330M)** | TBD | TBD | TBD | TBD | |
| 184 | +| **DiT-L (450M)** | TBD | TBD | TBD | TBD |
| 185 | +| **DiT-XL (700M)** | ✅ | ✅ | ✅ | ✅ | |