 
import logging
from dataclasses import dataclass
-from typing import Callable
 
import torch
from megatron.bridge.models.model_provider import ModelProviderMixin
@@ -39,14 +38,14 @@ class DiTModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
    add_bias_linear: bool = False
    gated_linear_unit: bool = False
 
-    num_layers: int = 28
-    hidden_size: int = 1152
+    num_layers: int = 12
+    hidden_size: int = 384
    max_img_h: int = 80
    max_img_w: int = 80
    max_frames: int = 34
    patch_spatial: int = 2
    patch_temporal: int = 1
-    num_attention_heads: int = 16
+    num_attention_heads: int = 6
    layernorm_epsilon = 1e-6
    normalization = "RMSNorm"
    add_bias_linear: bool = False
@@ -110,52 +109,27 @@ def configure_vae(self):
 
 
@dataclass
-class DiT7BModelProvider(DiTModelProvider):
-    hidden_size: int = 4096
-    max_img_h: int = 240
-    max_img_w: int = 240
-    max_frames: int = 128
-    num_attention_heads: int = 32
+class DiTBModelProvider(DiTModelProvider):
+    """DiT-B"""
 
-    apply_rope_fusion: bool = True  # TODO: do we support this?
-    additional_timestamp_channels = None  # TODO: do we support this?
-    vae_module: str = None
-    vae_path: str = None
+    num_layers: int = 12
+    hidden_size: int = 768
+    num_attention_heads: int = 12
 
 
@dataclass
-class DiT14BModelProvider(DiTModelProvider):
-    num_layers: int = 36
-    hidden_size: int = 5120
-    max_img_h: int = 240
-    max_img_w: int = 240
-    max_frames: int = 128
-    num_attention_heads: int = 40
-    apply_rope_fusion: bool = True
-    layernorm_zero_centered_gamma: bool = False
-    additional_timestamp_channels = None
-    vae_module: str = None
-    vae_path: str = None
-    loss_add_logvar: bool = True
+class DiTLModelProvider(DiTModelProvider):
+    """DiT-L"""
+
+    num_layers: int = 24
+    hidden_size: int = 1024
+    num_attention_heads: int = 16
 
 
@dataclass
-class DiTLlama30BConfig(DiTModelProvider):
-    num_layers: int = 48
-    hidden_size: int = 6144
-    ffn_hidden_size: int = 16384
-    num_attention_heads: int = 48
-    num_query_groups: int = 8
-    gated_linear_unit: int = True
-    bias_activation_fusion: int = True
-    activation_func: Callable = torch.nn.functional.silu
-    layernorm_epsilon: float = 1e-5
-    max_frames: int = 128
-    max_img_h: int = 240
-    max_img_w: int = 240
-    init_method_std: float = 0.01
-    add_bias_linear: bool = False
-    seq_length: int = 256
-    masked_softmax_fusion: bool = True
-    persist_layer_norm: bool = True
-    bias_dropout_fusion: bool = True
+class DiTXLModelProvider(DiTModelProvider):
+    """DiT-XL"""
+
+    num_layers: int = 28
+    hidden_size: int = 1152
+    num_attention_heads: int = 16
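
Review note: a minimal usage sketch of the new size variants for anyone wanting to try them. The import path and the assumption that a provider constructs cleanly from its defaults are mine, not something this diff establishes; adjust the module path to wherever this file lives in megatron.bridge.

```python
# Sketch only: the import path below is assumed, not taken from this diff.
from megatron.bridge.models.dit.dit_provider import DiTBModelProvider, DiTXLModelProvider

# DiT-B picks up the defaults added in this change: 12 layers, hidden size 768, 12 heads.
dit_b = DiTBModelProvider()

# Inherited fields (image/frame extents, patch sizes, ...) can still be overridden
# per experiment, since the variants are plain dataclasses.
dit_xl = DiTXLModelProvider(max_frames=16, patch_spatial=2)

print(dit_b.num_layers, dit_b.hidden_size, dit_b.num_attention_heads)     # 12 768 12
print(dit_xl.num_layers, dit_xl.hidden_size, dit_xl.num_attention_heads)  # 28 1152 16
```

For reference, these widths line up with the DiT-B/L/XL configurations from the original DiT paper, and the new base defaults (12 layers, hidden size 384, 6 heads) correspond to DiT-S; everything else (patching, image/frame extents, VAE wiring) stays on the base DiTModelProvider.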