Commit 644970d

Update the DiT configs to align with the original paper.

Signed-off-by: sajadn <[email protected]>
Parent: dee1153

3 files changed (+23, -49 lines)

dfm/src/megatron/model/dit/dit_model_provider.py

Lines changed: 20 additions & 46 deletions
@@ -14,7 +14,6 @@
 
 import logging
 from dataclasses import dataclass
-from typing import Callable
 
 import torch
 from megatron.bridge.models.model_provider import ModelProviderMixin
@@ -39,14 +38,14 @@ class DiTModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
     add_bias_linear: bool = False
     gated_linear_unit: bool = False
 
-    num_layers: int = 28
-    hidden_size: int = 1152
+    num_layers: int = 12
+    hidden_size: int = 384
     max_img_h: int = 80
     max_img_w: int = 80
     max_frames: int = 34
     patch_spatial: int = 2
     patch_temporal: int = 1
-    num_attention_heads: int = 16
+    num_attention_heads: int = 6
     layernorm_epsilon = 1e-6
     normalization = "RMSNorm"
     add_bias_linear: bool = False
@@ -110,52 +109,27 @@ def configure_vae(self):
 
 
 @dataclass
-class DiT7BModelProvider(DiTModelProvider):
-    hidden_size: int = 4096
-    max_img_h: int = 240
-    max_img_w: int = 240
-    max_frames: int = 128
-    num_attention_heads: int = 32
+class DiTBModelProvider(DiTModelProvider):
+    """DiT-B"""
 
-    apply_rope_fusion: bool = True  # TODO: do we support this?
-    additional_timestamp_channels = None  # TODO: do we support this?
-    vae_module: str = None
-    vae_path: str = None
+    num_layers: int = 12
+    hidden_size: int = 768
+    num_attention_heads: int = 12
 
 
 @dataclass
-class DiT14BModelProvider(DiTModelProvider):
-    num_layers: int = 36
-    hidden_size: int = 5120
-    max_img_h: int = 240
-    max_img_w: int = 240
-    max_frames: int = 128
-    num_attention_heads: int = 40
-    apply_rope_fusion: bool = True
-    layernorm_zero_centered_gamma: bool = False
-    additional_timestamp_channels = None
-    vae_module: str = None
-    vae_path: str = None
-    loss_add_logvar: bool = True
+class DiTLModelProvider(DiTModelProvider):
+    """DiT-L"""
+
+    num_layers: int = 24
+    hidden_size: int = 1024
+    num_attention_heads: int = 16
 
 
 @dataclass
-class DiTLlama30BConfig(DiTModelProvider):
-    num_layers: int = 48
-    hidden_size: int = 6144
-    ffn_hidden_size: int = 16384
-    num_attention_heads: int = 48
-    num_query_groups: int = 8
-    gated_linear_unit: int = True
-    bias_activation_fusion: int = True
-    activation_func: Callable = torch.nn.functional.silu
-    layernorm_epsilon: float = 1e-5
-    max_frames: int = 128
-    max_img_h: int = 240
-    max_img_w: int = 240
-    init_method_std: float = 0.01
-    add_bias_linear: bool = False
-    seq_length: int = 256
-    masked_softmax_fusion: bool = True
-    persist_layer_norm: bool = True
-    bias_dropout_fusion: bool = True
+class DiTXLModelProvider(DiTModelProvider):
+    """DiT-XL"""
+
+    num_layers: int = 28
+    hidden_size: int = 1152
+    num_attention_heads: int = 16
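
The renamed providers track the size variants of the original DiT paper, and the base class now defaults to DiT-S dimensions. As a minimal, hypothetical sketch (not part of this commit), the variants could be selected by size name; the dit_provider helper and the registry below are assumptions, while the class names and hyperparameters come from the diff above.

# Hypothetical helper, not part of the commit: pick a provider by DiT size name.
# Class names and hyperparameter comments mirror the diff above; everything else is assumed.
from dfm.src.megatron.model.dit.dit_model_provider import (
    DiTModelProvider,    # base defaults now match DiT-S: 12 layers, hidden 384, 6 heads
    DiTBModelProvider,   # DiT-B:  12 layers, hidden 768, 12 heads
    DiTLModelProvider,   # DiT-L:  24 layers, hidden 1024, 16 heads
    DiTXLModelProvider,  # DiT-XL: 28 layers, hidden 1152, 16 heads
)

_PROVIDERS = {
    "S": DiTModelProvider,
    "B": DiTBModelProvider,
    "L": DiTLModelProvider,
    "XL": DiTXLModelProvider,
}


def dit_provider(size: str = "XL", **overrides):
    """Return an instance of the provider dataclass for the requested DiT size."""
    return _PROVIDERS[size](**overrides)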

dfm/src/megatron/recipes/dit/dit.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@
 
 from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig
 from dfm.src.megatron.data.dit.dit_mock_datamodule import DiTMockDataModuleConfig
-from dfm.src.megatron.model.dit.dit_model_provider import DiTModelProvider
+from dfm.src.megatron.model.dit.dit_model_provider import DiTModelProvider, DiTXLModelProvider
 
 
 def model_config(
@@ -57,7 +57,7 @@ def model_config(
     Returns:
         DiTModelProvider: Configuration for the DiT-S model.
     """
-    return DiTModelProvider(
+    return DiTXLModelProvider(
         tensor_model_parallel_size=tensor_parallelism,
         pipeline_model_parallel_size=pipeline_parallelism,
         pipeline_dtype=pipeline_parallelism_dtype,
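
For reference, a minimal sketch of what the recipe now builds: model_config() returns a DiT-XL provider. The keyword names below come from the call in the diff; the parallelism values (1, 1, None) are placeholder assumptions, not values taken from the repository.

# Sketch only: construct the DiT-XL provider directly with placeholder parallelism settings.
from dfm.src.megatron.model.dit.dit_model_provider import DiTXLModelProvider

provider = DiTXLModelProvider(
    tensor_model_parallel_size=1,    # assumed example value
    pipeline_model_parallel_size=1,  # assumed example value
    pipeline_dtype=None,             # assumed example value
)
# DiT-XL hyperparameters are inherited from the provider defaults shown in the diff.
assert provider.num_layers == 28 and provider.hidden_size == 1152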

examples/megatron/recipes/dit/inference_dit_model.py

Lines changed: 1 addition & 1 deletion
@@ -19,12 +19,12 @@
 
 import numpy as np
 import torch
+import wandb
 from einops import rearrange
 from megatron.core import parallel_state as ps
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from transformers import T5EncoderModel, T5TokenizerFast
 
-import wandb
 from dfm.src.common.tokenizers.cosmos.cosmos1.causal_video_tokenizer import CausalVideoTokenizer
 from dfm.src.common.utils.save_video import save_video
 from dfm.src.megatron.model.dit.edm.edm_pipeline import EDMPipeline
