
Commit 1061749

sajadn and abhinavg4 authored

DiT unit tests (#68)

* edm and data preprocess tests.
* Minor cleanings for DiT.
* add dit unit test.
* add iter to the DiffusionDataModule.
* add missing copyright.
* use 'no caption' if caption is not present.
* fix dit inference bug. Add wandb to inference code.
* update the DiT configs to be aligned with the original paper.
* add wandb[video] and mediapy to uv.
* adjust pos_ids in mock_dataset to have batch dimension, fuse adaLN layers, use DiTSelfAttention.
* fix the diffusion sample size bug.
* fix broken tests.

Signed-off-by: sajadn <[email protected]>
Signed-off-by: Sajad Norouzi <[email protected]>
Co-authored-by: Abhinav Garg <[email protected]>
1 parent 7270c9f commit 1061749

29 files changed: +2711 -237 lines changed

dfm/src/common/utils/save_video.py

Lines changed: 0 additions & 1 deletion
@@ -44,5 +44,4 @@ def save_video(
         "output_params": ["-f", "mp4"],
     }

-    print("video_save_path", video_save_path)
     imageio.mimsave(video_save_path, grid, "mp4", **kwargs)

dfm/src/megatron/data/common/diffusion_energon_datamodule.py

Lines changed: 5 additions & 1 deletion
@@ -55,7 +55,11 @@ def __post_init__(self):
         self.sequence_length = self.dataset.seq_length

     def build_datasets(self, context: DatasetBuildContext):
-        return self.dataset.train_dataloader(), self.dataset.val_dataloader(), self.dataset.test_dataloader()
+        return (
+            iter(self.dataset.train_dataloader()),
+            iter(self.dataset.val_dataloader()),
+            iter(self.dataset.val_dataloader()),
+        )


 class DiffusionDataModule(EnergonMultiModalDataModule):
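
The iter(...) wrapping above matters because a dataloader configured as "external" is typically consumed by the training loop calling next() on whatever build_datasets returns. A minimal sketch of that consumption pattern, assuming a toy datamodule (the class and everything outside the diff are illustrative, not the repository's API):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


# Hypothetical stand-in for the datamodule wrapped by the config above.
class _ToyDataModule:
    def __init__(self):
        self._dl = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)

    def train_dataloader(self):
        return self._dl

    def val_dataloader(self):
        return self._dl


def build_datasets(dm):
    # Return iterators rather than DataLoader objects so the training loop
    # can drive batches explicitly with next().
    return iter(dm.train_dataloader()), iter(dm.val_dataloader()), iter(dm.val_dataloader())


train_it, val_it, test_it = build_datasets(_ToyDataModule())
first_batch = next(train_it)  # the consumer pulls batches one at a time
```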

dfm/src/megatron/data/common/diffusion_sample.py

Lines changed: 17 additions & 7 deletions
@@ -80,25 +80,35 @@ def to_dict(self) -> dict:
     def __add__(self, other: Any) -> int:
         """Adds the sequence length of this sample with another sample or integer."""
         if isinstance(other, DiffusionSample):
-            # Combine the values of the two instances
-            return self.seq_len_q.item() + other.seq_len_q.item()
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item()
+            return self_len + other_len
         elif isinstance(other, int):
-            # Add an integer to the value
-            return self.seq_len_q.item() + other
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            return self_len + other
         raise NotImplementedError

     def __radd__(self, other: Any) -> int:
         """Handles reverse addition for summing with integers."""
         # This is called if sum or other operations start with a non-DiffusionSample object.
         # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__.
         if isinstance(other, int):
-            return self.seq_len_q.item() + other
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            return self_len + other
         raise NotImplementedError

     def __lt__(self, other: Any) -> bool:
         """Compares this sample's sequence length with another sample or integer."""
         if isinstance(other, DiffusionSample):
-            return self.seq_len_q.item() < other.seq_len_q.item()
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item()
+            return self_len < other_len
         elif isinstance(other, int):
-            return self.seq_len_q.item() < other
+            # Use padded length if available (for CP), otherwise use unpadded
+            self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
+            return self_len < other
         raise NotImplementedError
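
A self-contained sketch of the length selection these dunder methods now share (the two fields mirror the diff; the helper and toy values are assumptions): when context parallelism pads each sample, packing arithmetic has to budget against the padded length, otherwise a pack can exceed the physical sequence length.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _SampleLens:
    # Mirrors the two fields consulted in the diff; other DiffusionSample fields omitted.
    seq_len_q: torch.Tensor
    seq_len_q_padded: Optional[torch.Tensor] = None

    def effective_len(self) -> int:
        # Budget against the padded length when CP padding is present.
        t = self.seq_len_q_padded if self.seq_len_q_padded is not None else self.seq_len_q
        return int(t.item())


a = _SampleLens(torch.tensor(100), torch.tensor(128))  # padded for CP
b = _SampleLens(torch.tensor(50))                      # no padding
assert a.effective_len() + b.effective_len() == 178    # 128 + 50, not 150
```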

dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ def __init__(
         self,
         *args,
         max_frames: int = None,
-        text_embedding_padding_size: int = 512,
+        text_embedding_max_length: int = 512,
         seq_length: int = None,
         patch_spatial: int = 2,
         patch_temporal: int = 1,
@@ -65,7 +65,7 @@ def __init__(
     ):
         super().__init__(*args, **kwargs)
         self.max_frames = max_frames
-        self.text_embedding_padding_size = text_embedding_padding_size
+        self.text_embedding_max_length = text_embedding_max_length
         self.seq_length = seq_length
         self.patch_spatial = patch_spatial
         self.patch_temporal = patch_temporal

dfm/src/megatron/data/common/sequence_packing_utils.py

Lines changed: 0 additions & 32 deletions
@@ -71,35 +71,3 @@ def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]:
     """
     sorted_seqlens = sorted(seqlens, reverse=True)
     return first_fit(sorted_seqlens, pack_size)
-
-
-def concat_pad(tensor_list, max_seq_length):
-    """
-    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
-    to reach max_seq_length.
-
-    Args:
-        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
-        max_seq_length (int): The desired size of the first dimension of the output tensor.
-
-    Returns:
-        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
-    """
-    import torch
-
-    # Get common properties from the first tensor
-    other_shape = tensor_list[0].shape[1:]
-    dtype = tensor_list[0].dtype
-    device = tensor_list[0].device
-
-    # Initialize the result tensor with zeros
-    result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)
-
-    current_index = 0
-    for tensor in tensor_list:
-        length = tensor.shape[0]
-        # Directly assign the tensor to the result tensor without checks
-        result[current_index : current_index + length] = tensor
-        current_index += length
-
-    return result
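
For reference, a minimal sketch of the first-fit-decreasing packing this module keeps (the first_fit helper below is a generic re-implementation for illustration, not necessarily the repository's exact code):

```python
from typing import List


def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]:
    # Place each length into the first bin with enough remaining room,
    # opening a new bin when none fits.
    bins: List[List[int]] = []
    for length in seqlens:
        for b in bins:
            if sum(b) + length <= pack_size:
                b.append(length)
                break
        else:
            bins.append([length])
    return bins


def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]:
    # Sorting longest-first before first-fit usually tightens the packing.
    return first_fit(sorted(seqlens, reverse=True), pack_size)


print(first_fit_decreasing([7, 2, 5, 4, 3], pack_size=8))  # -> [[7], [5, 3], [4, 2]]
```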

dfm/src/megatron/data/dit/dit_mock_datamodule.py

Lines changed: 9 additions & 4 deletions
@@ -113,7 +113,7 @@ def mock_batch(
         seq_len_kv=seq_len_kv_packed,
         seq_len_kv_padded=seq_len_kv_padded_packed,
         latent_shape=torch.tensor([[C, T, H, W] for _ in range(number_packed_samples)], dtype=torch.int32),
-        pos_ids=pos_ids_packed,
+        pos_ids=pos_ids_packed.unsqueeze(0),
         video_metadata=[{"caption": f"Mock video sample {i}"} for i in range(number_packed_samples)],
     )

@@ -131,16 +131,19 @@ class DiTMockDataModuleConfig(DatasetProvider):
     dataloader_type: str = "external"
     task_encoder_seq_length: int = None
     F_latents: int = 1
-    H_latents: int = 64
-    W_latents: int = 96
+    H_latents: int = 256
+    W_latents: int = 512
     patch_spatial: int = 2
     patch_temporal: int = 1
-    number_packed_samples: int = 3
+    number_packed_samples: int = 1
     context_seq_len: int = 512
     context_embeddings_dim: int = 1024

     def __post_init__(self):
         mock_ds = _MockDataset(length=1024)
+        kwargs = {}
+        if self.num_workers > 0:
+            kwargs["prefetch_factor"] = 8
         self._train_dl = DataLoader(
             mock_ds,
             batch_size=self.micro_batch_size,
@@ -157,6 +160,8 @@ def __post_init__(self):
             ),
             shuffle=False,
             drop_last=False,
+            pin_memory=True,
+            **kwargs,
         )
         self._train_dl = iter(self._train_dl)
         self.sequence_length = self.seq_length
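
The kwargs guard above reflects a PyTorch constraint: torch.utils.data.DataLoader only accepts prefetch_factor when num_workers > 0. A minimal sketch of that pattern, assuming a stand-in dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(16, 4))
num_workers = 0  # e.g. taken from the datamodule config

# prefetch_factor is only valid with worker processes, so pass it conditionally.
extra_kwargs = {}
if num_workers > 0:
    extra_kwargs["prefetch_factor"] = 8

loader = DataLoader(
    dataset,
    batch_size=2,
    num_workers=num_workers,
    pin_memory=True,  # helps overlap host-to-device copies when training on GPU
    **extra_kwargs,
)
batch = next(iter(loader))
```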

dfm/src/megatron/data/dit/dit_taskencoder.py

Lines changed: 4 additions & 5 deletions
@@ -31,9 +31,9 @@ class DiTTaskEncoder(DiffusionTaskEncoderWithSequencePacking):
     Attributes:
         cookers (list): A list of Cooker objects used for processing.
         max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None.
-        text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512.
+        text_embedding_max_length (int): The maximum length for text embeddings. Defaults to 512.
     Methods:
-        __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs):
+        __init__(*args, max_frames=None, text_embedding_max_size=512, **kwargs):
            Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size.
        encode_sample(sample: dict) -> dict:
            Encodes a given sample dictionary containing video and text data.
@@ -71,7 +71,6 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
             // self.patch_spatial**2
             // self.patch_temporal
         )
-        is_image = T == 1

         if seq_len > self.seq_length:
             print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}")
@@ -100,8 +99,8 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
         t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16)
         t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

-        if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
-            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
+        if t5_text_embeddings_seq_length > self.text_embedding_max_length:
+            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_max_length]
         t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

         pos_ids = rearrange(
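
The seq_len guard in this encoder budgets each sample by its number of latent patches before packing; a rough sketch of that bookkeeping plus the text-embedding truncation (shapes and variable names outside the diff are illustrative assumptions):

```python
import torch

# Latent video of shape [C, T, H, W]; patchification turns it into a token sequence.
C, T, H, W = 16, 1, 32, 32
patch_spatial, patch_temporal = 2, 1
seq_length_budget = 512  # max tokens a single sample may contribute to a pack

seq_len = T * H * W // patch_spatial**2 // patch_temporal
if seq_len > seq_length_budget:
    print(f"Skipping sample because seq_len {seq_len} > budget {seq_length_budget}")

# Truncate T5 text embeddings to a maximum context length before packing.
text_embedding_max_length = 512
t5_text_embeddings = torch.randn(600, 1024, dtype=torch.bfloat16)
if t5_text_embeddings.shape[0] > text_embedding_max_length:
    t5_text_embeddings = t5_text_embeddings[:text_embedding_max_length]
```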

dfm/src/megatron/model/dit/dit_layer_spec.py

Lines changed: 18 additions & 7 deletions
@@ -20,8 +20,8 @@

 import torch
 import torch.nn as nn
+from megatron.core.jit import jit_fuser
 from megatron.core.transformer.attention import (
-    SelfAttention,
     SelfAttentionSubmodules,
 )
 from megatron.core.transformer.custom_layers.transformer_engine import (
@@ -41,7 +41,11 @@
 from megatron.core.utils import make_viewless_tensor

 # to be imported from common
-from dfm.src.megatron.model.common.dit_attention import DiTCrossAttention, DiTCrossAttentionSubmodules
+from dfm.src.megatron.model.common.dit_attention import (
+    DiTCrossAttention,
+    DiTCrossAttentionSubmodules,
+    DiTSelfAttention,
+)


 @dataclass
@@ -91,19 +95,24 @@ def __init__(

         setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)

+    @jit_fuser
     def forward(self, timestep_emb):
         return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1)

+    @jit_fuser
     def modulate(self, x, shift, scale):
         return x * (1 + scale) + shift

+    @jit_fuser
     def scale_add(self, residual, x, gate):
         return residual + gate * x

+    @jit_fuser
     def modulated_layernorm(self, x, shift, scale):
         input_layernorm_output = self.ln(x).type_as(x)
         return self.modulate(input_layernorm_output, shift, scale)

+    @jit_fuser
     def scaled_modulated_layernorm(self, residual, x, gate, shift, scale):
         hidden_states = self.scale_add(residual, x, gate)
         shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale)
@@ -156,7 +165,9 @@ def _replace_no_cp_submodules(submodules):
             layer_number=layer_number,
         )

-        self.adaLN = AdaLN(config=self.config, n_adaln_chunks=9 if self.cross_attention else 6)
+        self.adaLN = AdaLN(
+            config=self.config, n_adaln_chunks=9 if not isinstance(self.cross_attention, IdentityOp) else 6
+        )

     def forward(
         self,
@@ -176,7 +187,7 @@ def forward(
     ):
         timestep_emb = attention_mask

-        if self.cross_attention:
+        if not isinstance(self.cross_attention, IdentityOp):
            shift_full, scale_full, gate_full, shift_ca, scale_ca, gate_ca, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN(timestep_emb)
            )
@@ -192,7 +203,7 @@ def forward(
             packed_seq_params=None if packed_seq_params is None else packed_seq_params["self_attention"],
         )

-        if self.cross_attention:
+        if not isinstance(self.cross_attention, IdentityOp):
            hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
                residual=hidden_states,
                x=attention_output,
@@ -210,7 +221,7 @@ def forward(
         hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
             residual=hidden_states,
             x=attention_output,
-            gate=gate_ca if self.cross_attention else gate_full,
+            gate=gate_ca if not isinstance(self.cross_attention, IdentityOp) else gate_full,
             shift=shift_mlp,
             scale=scale_mlp,
         )
@@ -234,7 +245,7 @@ def get_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec:
         module=DiTLayerWithAdaLN,
         submodules=DiTWithAdaLNSubmodules(
             full_self_attention=ModuleSpec(
-                module=SelfAttention,
+                module=DiTSelfAttention,
                 params=params,
                 submodules=SelfAttentionSubmodules(
                     linear_qkv=TEColumnParallelLinear,
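
The newly decorated helpers are the standard adaLN modulation primitives (jit_fuser is Megatron-Core's decorator for fusing small elementwise functions). A self-contained sketch of the pattern, where the toy module, shapes, and chunk count are assumptions and the repo's AdaLN projects the timestep embedding and chunks it into 6 or 9 modulation tensors depending on whether cross-attention is present:

```python
import torch
import torch.nn as nn


class TinyAdaLN(nn.Module):
    # Illustrative only: timestep embedding -> (shift, scale, gate) per modulated sub-block.
    def __init__(self, hidden_size: int, n_chunks: int = 6):
        super().__init__()
        self.n_chunks = n_chunks
        self.ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_chunks * hidden_size))

    def forward(self, timestep_emb: torch.Tensor):
        return self.adaLN_modulation(timestep_emb).chunk(self.n_chunks, dim=-1)

    def modulate(self, x, shift, scale):
        return x * (1 + scale) + shift

    def scale_add(self, residual, x, gate):
        return residual + gate * x

    def modulated_layernorm(self, x, shift, scale):
        return self.modulate(self.ln(x).type_as(x), shift, scale)


ada = TinyAdaLN(hidden_size=64, n_chunks=6)
t_emb = torch.randn(4, 64)  # timestep embedding per sample
x = torch.randn(4, 64)      # hidden states (batch, hidden)
shift_sa, scale_sa, gate_sa, shift_mlp, scale_mlp, gate_mlp = ada(t_emb)
normed = ada.modulated_layernorm(x, shift_sa, scale_sa)  # pre-attention modulation
out = ada.scale_add(x, normed, gate_sa)                  # gated residual add
```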

dfm/src/megatron/model/dit/dit_model_provider.py

Lines changed: 20 additions & 46 deletions
@@ -14,7 +14,6 @@

 import logging
 from dataclasses import dataclass
-from typing import Callable

 import torch
 from megatron.bridge.models.model_provider import ModelProviderMixin
@@ -39,14 +38,14 @@ class DiTModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
     add_bias_linear: bool = False
     gated_linear_unit: bool = False

-    num_layers: int = 28
-    hidden_size: int = 1152
+    num_layers: int = 12
+    hidden_size: int = 384
     max_img_h: int = 80
     max_img_w: int = 80
     max_frames: int = 34
     patch_spatial: int = 2
     patch_temporal: int = 1
-    num_attention_heads: int = 16
+    num_attention_heads: int = 6
     layernorm_epsilon = 1e-6
     normalization = "RMSNorm"
     add_bias_linear: bool = False
@@ -110,52 +109,27 @@ def configure_vae(self):


 @dataclass
-class DiT7BModelProvider(DiTModelProvider):
-    hidden_size: int = 4096
-    max_img_h: int = 240
-    max_img_w: int = 240
-    max_frames: int = 128
-    num_attention_heads: int = 32
+class DiTBModelProvider(DiTModelProvider):
+    """DiT-B"""

-    apply_rope_fusion: bool = True  # TODO: do we support this?
-    additional_timestamp_channels = None  # TODO: do we support this?
-    vae_module: str = None
-    vae_path: str = None
+    num_layers: int = 12
+    hidden_size: int = 768
+    num_attention_heads: int = 12


 @dataclass
-class DiT14BModelProvider(DiTModelProvider):
-    num_layers: int = 36
-    hidden_size: int = 5120
-    max_img_h: int = 240
-    max_img_w: int = 240
-    max_frames: int = 128
-    num_attention_heads: int = 40
-    apply_rope_fusion: bool = True
-    layernorm_zero_centered_gamma: bool = False
-    additional_timestamp_channels = None
-    vae_module: str = None
-    vae_path: str = None
-    loss_add_logvar: bool = True
+class DiTLModelProvider(DiTModelProvider):
+    """DiT-L"""
+
+    num_layers: int = 24
+    hidden_size: int = 1024
+    num_attention_heads: int = 16


 @dataclass
-class DiTLlama30BConfig(DiTModelProvider):
-    num_layers: int = 48
-    hidden_size: int = 6144
-    ffn_hidden_size: int = 16384
-    num_attention_heads: int = 48
-    num_query_groups: int = 8
-    gated_linear_unit: int = True
-    bias_activation_fusion: int = True
-    activation_func: Callable = torch.nn.functional.silu
-    layernorm_epsilon: float = 1e-5
-    max_frames: int = 128
-    max_img_h: int = 240
-    max_img_w: int = 240
-    init_method_std: float = 0.01
-    add_bias_linear: bool = False
-    seq_length: int = 256
-    masked_softmax_fusion: bool = True
-    persist_layer_norm: bool = True
-    bias_dropout_fusion: bool = True
+class DiTXLModelProvider(DiTModelProvider):
+    """DiT-XL"""
+
+    num_layers: int = 28
+    hidden_size: int = 1152
+    num_attention_heads: int = 16
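
The new providers mirror the standard sizes from the original DiT paper (DiT-B: 12 layers, hidden 768, 12 heads; DiT-L: 24, 1024, 16; DiT-XL: 28, 1152, 16). A minimal sketch of the dataclass-override pattern they use, with simplified stand-in classes since the real providers also inherit TransformerConfig arguments not shown in the diff:

```python
from dataclasses import dataclass


@dataclass
class BaseDiTConfig:
    # Small defaults (handy for unit tests); subclasses override to published sizes.
    num_layers: int = 12
    hidden_size: int = 384
    num_attention_heads: int = 6


@dataclass
class DiTXL(BaseDiTConfig):
    """DiT-XL per the original DiT paper."""

    num_layers: int = 28
    hidden_size: int = 1152
    num_attention_heads: int = 16


cfg = DiTXL()
assert cfg.hidden_size // cfg.num_attention_heads == 72  # per-head dimension
```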

dfm/src/megatron/model/dit/dit_step.py

Lines changed: 1 addition & 2 deletions
@@ -50,7 +50,7 @@ def on_validation_start(self, state, batch, model):
             num_steps=model.config.val_generation_num_steps,
             is_negative_prompt=True if "neg_context_embeddings" in batch else False,
         )
-        caption = batch["video_metadata"][0]["caption"]
+        caption = batch["video_metadata"][0]["caption"] if "caption" in batch["video_metadata"][0] else "no caption"
        latent = latent[0, None, : batch["seq_len_q"][0]]
        latent = rearrange(
            latent,
@@ -157,7 +157,6 @@ def forward_step(self, state, batch, model, return_schedule_plan: bool = False):

         check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss
         check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss
-        # import pdb;pdb.set_trace()
         straggler_timer = state.straggler_timer
         with straggler_timer:
             if parallel_state.is_pipeline_last_stage():
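
The caption fallback added above is equivalent to a dict.get with a default; a trivial sketch, assuming video_metadata is a list of plain dicts:

```python
batch = {"video_metadata": [{}]}  # metadata without a caption
meta = batch["video_metadata"][0]
caption = meta.get("caption", "no caption")  # same effect as the conditional in the diff
assert caption == "no caption"
```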

0 commit comments
