1 change: 0 additions & 1 deletion dfm/src/common/utils/save_video.py
@@ -44,5 +44,4 @@ def save_video(
"output_params": ["-f", "mp4"],
}

print("video_save_path", video_save_path)
Contributor comment: Good cleanup! Removing debug print statements keeps the output clean in production.

imageio.mimsave(video_save_path, grid, "mp4", **kwargs)
6 changes: 5 additions & 1 deletion dfm/src/megatron/data/common/diffusion_energon_datamodule.py
@@ -55,7 +55,11 @@ def __post_init__(self):
self.sequence_length = self.dataset.seq_length

def build_datasets(self, context: DatasetBuildContext):
return self.dataset.train_dataloader(), self.dataset.val_dataloader(), self.dataset.test_dataloader()
return (
iter(self.dataset.train_dataloader()),
Contributor comment: Nice improvement! Wrapping dataloaders with iter() makes the interface more explicit and reduces potential confusion in downstream usage.

iter(self.dataset.val_dataloader()),
iter(self.dataset.val_dataloader()),
)


class DiffusionDataModule(EnergonMultiModalDataModule):
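A minimal sketch (hypothetical, not part of the diff) of why wrapping the dataloaders in `iter()` matters downstream: a consumer that pulls batches with `next()` needs an iterator, whereas a bare `DataLoader` only supports `iter()`/`for` iteration.

```python
# Hypothetical sketch: next() works on the iterator that build_datasets now
# returns, but would raise TypeError on a bare DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
train_dl = DataLoader(dataset, batch_size=2)

train_iter = iter(train_dl)     # explicit iterator, mirroring build_datasets
first_batch = next(train_iter)  # fine; next(train_dl) would raise TypeError
print(first_batch)
```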
24 changes: 17 additions & 7 deletions dfm/src/megatron/data/common/diffusion_sample.py
@@ -80,25 +80,35 @@ def to_dict(self) -> dict:
def __add__(self, other: Any) -> int:
"""Adds the sequence length of this sample with another sample or integer."""
if isinstance(other, DiffusionSample):
# Combine the values of the two instances
return self.seq_len_q.item() + other.seq_len_q.item()
# Use padded length if available (for CP), otherwise use unpadded
self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item()
return self_len + other_len
elif isinstance(other, int):
# Add an integer to the value
return self.seq_len_q.item() + other
# Use padded length if available (for CP), otherwise use unpadded
self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
return self_len + other
raise NotImplementedError

def __radd__(self, other: Any) -> int:
"""Handles reverse addition for summing with integers."""
# This is called if sum or other operations start with a non-DiffusionSample object.
# e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__.
if isinstance(other, int):
return self.seq_len_q.item() + other
# Use padded length if available (for CP), otherwise use unpadded
self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
return self_len + other
raise NotImplementedError

def __lt__(self, other: Any) -> bool:
"""Compares this sample's sequence length with another sample or integer."""
if isinstance(other, DiffusionSample):
return self.seq_len_q.item() < other.seq_len_q.item()
# Use padded length if available (for CP), otherwise use unpadded
self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
other_len = other.seq_len_q_padded.item() if other.seq_len_q_padded is not None else other.seq_len_q.item()
return self_len < other_len
elif isinstance(other, int):
return self.seq_len_q.item() < other
# Use padded length if available (for CP), otherwise use unpadded
self_len = self.seq_len_q_padded.item() if self.seq_len_q_padded is not None else self.seq_len_q.item()
return self_len < other
raise NotImplementedError
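A self-contained, hypothetical sketch of what the dunder methods above buy: once `__add__`, `__radd__`, and `__lt__` are defined on the (padded, when available) sequence length, plain `sum()` and `sorted()` work directly on lists of samples during sequence packing. `_FakeSample` below is only a stand-in for `DiffusionSample`.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class _FakeSample:  # hypothetical stand-in for DiffusionSample
    seq_len_q: int
    seq_len_q_padded: Optional[int] = None

    def _len(self) -> int:
        # padded length wins when present (context parallelism), else unpadded
        return self.seq_len_q_padded if self.seq_len_q_padded is not None else self.seq_len_q

    def __add__(self, other):
        return self._len() + (other._len() if isinstance(other, _FakeSample) else other)

    __radd__ = __add__  # lets sum() start from 0

    def __lt__(self, other):
        return self._len() < (other._len() if isinstance(other, _FakeSample) else other)

samples = [_FakeSample(100, 128), _FakeSample(90), _FakeSample(50, 64)]
print(sum(samples))               # 128 + 90 + 64 = 282 (padded lengths dominate)
print(sorted(samples)[0]._len())  # 64: shortest (padded) sample first
```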
@@ -56,7 +56,7 @@ def __init__(
self,
*args,
max_frames: int = None,
text_embedding_padding_size: int = 512,
text_embedding_max_length: int = 512,
seq_length: int = None,
patch_spatial: int = 2,
patch_temporal: int = 1,
@@ -65,7 +65,7 @@
):
super().__init__(*args, **kwargs)
self.max_frames = max_frames
self.text_embedding_padding_size = text_embedding_padding_size
self.text_embedding_max_length = text_embedding_max_length
self.seq_length = seq_length
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
32 changes: 0 additions & 32 deletions dfm/src/megatron/data/common/sequence_packing_utils.py
@@ -71,35 +71,3 @@ def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]:
"""
sorted_seqlens = sorted(seqlens, reverse=True)
return first_fit(sorted_seqlens, pack_size)


def concat_pad(tensor_list, max_seq_length):
"""
Efficiently concatenates a list of tensors along the first dimension and pads with zeros
to reach max_seq_length.

Args:
tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
max_seq_length (int): The desired size of the first dimension of the output tensor.

Returns:
torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
"""
import torch

# Get common properties from the first tensor
other_shape = tensor_list[0].shape[1:]
dtype = tensor_list[0].dtype
device = tensor_list[0].device

# Initialize the result tensor with zeros
result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)

current_index = 0
for tensor in tensor_list:
length = tensor.shape[0]
# Directly assign the tensor to the result tensor without checks
result[current_index : current_index + length] = tensor
current_index += length

return result
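For context, a self-contained sketch of the first-fit-decreasing idea that `first_fit_decreasing` delegates to (the real `first_fit` lives earlier in sequence_packing_utils.py; this standalone version only illustrates the greedy packing behaviour).

```python
from typing import List

def first_fit_decreasing_sketch(seqlens: List[int], pack_size: int) -> List[List[int]]:
    """Illustrative only: pack sequence lengths greedily, longest first."""
    bins: List[List[int]] = []
    for length in sorted(seqlens, reverse=True):  # longest sequences first
        for b in bins:
            if sum(b) + length <= pack_size:      # first bin with room wins
                b.append(length)
                break
        else:
            bins.append([length])                 # no bin fits: open a new pack
    return bins

print(first_fit_decreasing_sketch([700, 300, 500, 200, 400], pack_size=1000))
# [[700, 300], [500, 400], [200]]
```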
13 changes: 9 additions & 4 deletions dfm/src/megatron/data/dit/dit_mock_datamodule.py
@@ -113,7 +113,7 @@ def mock_batch(
seq_len_kv=seq_len_kv_packed,
seq_len_kv_padded=seq_len_kv_padded_packed,
latent_shape=torch.tensor([[C, T, H, W] for _ in range(number_packed_samples)], dtype=torch.int32),
pos_ids=pos_ids_packed,
pos_ids=pos_ids_packed.unsqueeze(0),
video_metadata=[{"caption": f"Mock video sample {i}"} for i in range(number_packed_samples)],
)

@@ -131,16 +131,19 @@ class DiTMockDataModuleConfig(DatasetProvider):
dataloader_type: str = "external"
task_encoder_seq_length: int = None
F_latents: int = 1
H_latents: int = 64
W_latents: int = 96
H_latents: int = 256
W_latents: int = 512
patch_spatial: int = 2
patch_temporal: int = 1
number_packed_samples: int = 3
number_packed_samples: int = 1
context_seq_len: int = 512
context_embeddings_dim: int = 1024

def __post_init__(self):
mock_ds = _MockDataset(length=1024)
kwargs = {}
if self.num_workers > 0:
kwargs["prefetch_factor"] = 8
self._train_dl = DataLoader(
mock_ds,
batch_size=self.micro_batch_size,
@@ -157,6 +160,8 @@ def __post_init__(self):
),
shuffle=False,
drop_last=False,
pin_memory=True,
**kwargs,
)
self._train_dl = iter(self._train_dl)
self.sequence_length = self.seq_length
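A hedged sketch (not from this repo) of why `prefetch_factor` is only passed when `num_workers > 0`: PyTorch's `DataLoader` rejects `prefetch_factor` when there are no worker processes, so gating it through a kwargs dict keeps the single-process configuration valid.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.zeros(16, 3))

def make_loader(num_workers: int) -> DataLoader:
    kwargs = {}
    if num_workers > 0:
        kwargs["prefetch_factor"] = 8  # only meaningful with worker processes
    return DataLoader(ds, batch_size=4, num_workers=num_workers,
                      pin_memory=True, **kwargs)

loader = make_loader(num_workers=0)  # ok; passing prefetch_factor here would raise ValueError
```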
9 changes: 4 additions & 5 deletions dfm/src/megatron/data/dit/dit_taskencoder.py
@@ -31,9 +31,9 @@ class DiTTaskEncoder(DiffusionTaskEncoderWithSequencePacking):
Attributes:
cookers (list): A list of Cooker objects used for processing.
max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None.
text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512.
text_embedding_max_length (int): The maximum length for text embeddings. Defaults to 512.
Methods:
__init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs):
__init__(*args, max_frames=None, text_embedding_max_length=512, **kwargs):
Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size.
encode_sample(sample: dict) -> dict:
Encodes a given sample dictionary containing video and text data.
@@ -71,7 +71,6 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
// self.patch_spatial**2
// self.patch_temporal
)
is_image = T == 1

if seq_len > self.seq_length:
print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}")
@@ -100,8 +99,8 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16)
t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
if t5_text_embeddings_seq_length > self.text_embedding_max_length:
t5_text_embeddings = t5_text_embeddings[: self.text_embedding_max_length]
t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

pos_ids = rearrange(
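A hypothetical sketch of the truncation step above, with random data standing in for the T5 embeddings; the mask here follows the clipped length, which is an assumption about the intended behaviour of `text_embedding_max_length`.

```python
import torch

text_embedding_max_length = 512
t5_text_embeddings = torch.randn(600, 1024, dtype=torch.bfloat16)  # 600 tokens > 512

seq_len = t5_text_embeddings.shape[0]
if seq_len > text_embedding_max_length:
    t5_text_embeddings = t5_text_embeddings[:text_embedding_max_length]
    seq_len = text_embedding_max_length
t5_text_mask = torch.ones(seq_len, dtype=torch.bfloat16)

print(t5_text_embeddings.shape, t5_text_mask.shape)  # [512, 1024] and [512]
```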
25 changes: 18 additions & 7 deletions dfm/src/megatron/model/dit/dit_layer_spec.py
@@ -20,8 +20,8 @@

import torch
import torch.nn as nn
from megatron.core.jit import jit_fuser
from megatron.core.transformer.attention import (
SelfAttention,
SelfAttentionSubmodules,
)
from megatron.core.transformer.custom_layers.transformer_engine import (
@@ -41,7 +41,11 @@
from megatron.core.utils import make_viewless_tensor

# to be imported from common
from dfm.src.megatron.model.common.dit_attention import DiTCrossAttention, DiTCrossAttentionSubmodules
from dfm.src.megatron.model.common.dit_attention import (
DiTCrossAttention,
DiTCrossAttentionSubmodules,
DiTSelfAttention,
)


@dataclass
@@ -91,19 +95,24 @@ def __init__(

setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)

@jit_fuser
def forward(self, timestep_emb):
return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1)

@jit_fuser
def modulate(self, x, shift, scale):
return x * (1 + scale) + shift

@jit_fuser
def scale_add(self, residual, x, gate):
return residual + gate * x

@jit_fuser
def modulated_layernorm(self, x, shift, scale):
input_layernorm_output = self.ln(x).type_as(x)
return self.modulate(input_layernorm_output, shift, scale)

@jit_fuser
def scaled_modulated_layernorm(self, residual, x, gate, shift, scale):
hidden_states = self.scale_add(residual, x, gate)
shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale)
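A hedged, standalone sketch of the adaLN math these now `@jit_fuser`-decorated helpers implement: `modulated_layernorm` applies `x * (1 + scale) + shift` after a LayerNorm, and `scale_add` folds the gated branch output back into the residual stream; the decorator simply marks these small elementwise ops for JIT fusion.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8)                # toy [seq, batch, hidden] activations
shift, scale, gate = torch.randn(3, 8)  # one (shift, scale, gate) triple

ln = torch.nn.LayerNorm(8, elementwise_affine=False)
modulated = ln(x) * (1 + scale) + shift  # modulated_layernorm
residual = x + gate * modulated          # scale_add (`modulated` stands in for the branch output)
print(residual.shape)                    # torch.Size([2, 4, 8])
```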
@@ -156,7 +165,9 @@ def _replace_no_cp_submodules(submodules):
layer_number=layer_number,
)

self.adaLN = AdaLN(config=self.config, n_adaln_chunks=9 if self.cross_attention else 6)
self.adaLN = AdaLN(
config=self.config, n_adaln_chunks=9 if not isinstance(self.cross_attention, IdentityOp) else 6
)

def forward(
self,
@@ -176,7 +187,7 @@
):
timestep_emb = attention_mask

if self.cross_attention:
if not isinstance(self.cross_attention, IdentityOp):
shift_full, scale_full, gate_full, shift_ca, scale_ca, gate_ca, shift_mlp, scale_mlp, gate_mlp = (
self.adaLN(timestep_emb)
)
@@ -192,7 +203,7 @@
packed_seq_params=None if packed_seq_params is None else packed_seq_params["self_attention"],
)

if self.cross_attention:
if not isinstance(self.cross_attention, IdentityOp):
hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
residual=hidden_states,
x=attention_output,
@@ -210,7 +221,7 @@
hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
residual=hidden_states,
x=attention_output,
gate=gate_ca if self.cross_attention else gate_full,
gate=gate_ca if not isinstance(self.cross_attention, IdentityOp) else gate_full,
shift=shift_mlp,
scale=scale_mlp,
)
@@ -234,7 +245,7 @@ def get_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec:
module=DiTLayerWithAdaLN,
submodules=DiTWithAdaLNSubmodules(
full_self_attention=ModuleSpec(
module=SelfAttention,
module=DiTSelfAttention,
params=params,
submodules=SelfAttentionSubmodules(
linear_qkv=TEColumnParallelLinear,
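A small, hypothetical sketch of why the `isinstance(..., IdentityOp)` checks replace plain truthiness: an instantiated `nn.Module`, including a no-op placeholder, is always truthy, so `if self.cross_attention:` could not distinguish a real cross-attention block from an `IdentityOp`.

```python
import torch.nn as nn

class IdentityOpSketch(nn.Module):  # stand-in for megatron.core's IdentityOp
    def forward(self, x, *args, **kwargs):
        return x

cross_attention = IdentityOpSketch()

print(bool(cross_attention))                              # True: the old truthiness check always passes
print(not isinstance(cross_attention, IdentityOpSketch))  # False: the new check correctly skips cross-attention
```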
66 changes: 20 additions & 46 deletions dfm/src/megatron/model/dit/dit_model_provider.py
@@ -14,7 +14,6 @@

import logging
from dataclasses import dataclass
from typing import Callable

import torch
from megatron.bridge.models.model_provider import ModelProviderMixin
@@ -39,14 +38,14 @@ class DiTModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
add_bias_linear: bool = False
gated_linear_unit: bool = False

num_layers: int = 28
hidden_size: int = 1152
num_layers: int = 12
hidden_size: int = 384
max_img_h: int = 80
max_img_w: int = 80
max_frames: int = 34
patch_spatial: int = 2
patch_temporal: int = 1
num_attention_heads: int = 16
num_attention_heads: int = 6
layernorm_epsilon = 1e-6
normalization = "RMSNorm"
add_bias_linear: bool = False
@@ -110,52 +109,27 @@ def configure_vae(self):


@dataclass
class DiT7BModelProvider(DiTModelProvider):
hidden_size: int = 4096
max_img_h: int = 240
max_img_w: int = 240
max_frames: int = 128
num_attention_heads: int = 32
class DiTBModelProvider(DiTModelProvider):
"""DiT-B"""

apply_rope_fusion: bool = True # TODO: do we support this?
additional_timestamp_channels = None # TODO: do we support this?
vae_module: str = None
vae_path: str = None
num_layers: int = 12
hidden_size: int = 768
num_attention_heads: int = 12


@dataclass
class DiT14BModelProvider(DiTModelProvider):
num_layers: int = 36
hidden_size: int = 5120
max_img_h: int = 240
max_img_w: int = 240
max_frames: int = 128
num_attention_heads: int = 40
apply_rope_fusion: bool = True
layernorm_zero_centered_gamma: bool = False
additional_timestamp_channels = None
vae_module: str = None
vae_path: str = None
loss_add_logvar: bool = True
class DiTLModelProvider(DiTModelProvider):
"""DiT-L"""

num_layers: int = 24
hidden_size: int = 1024
num_attention_heads: int = 16


@dataclass
class DiTLlama30BConfig(DiTModelProvider):
num_layers: int = 48
hidden_size: int = 6144
ffn_hidden_size: int = 16384
num_attention_heads: int = 48
num_query_groups: int = 8
gated_linear_unit: int = True
bias_activation_fusion: int = True
activation_func: Callable = torch.nn.functional.silu
layernorm_epsilon: float = 1e-5
max_frames: int = 128
max_img_h: int = 240
max_img_w: int = 240
init_method_std: float = 0.01
add_bias_linear: bool = False
seq_length: int = 256
masked_softmax_fusion: bool = True
persist_layer_norm: bool = True
bias_dropout_fusion: bool = True
class DiTXLModelProvider(DiTModelProvider):
"""DiT-XL"""

num_layers: int = 28
hidden_size: int = 1152
num_attention_heads: int = 16
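For reference, the new presets line up with the standard DiT sizes from Peebles & Xie (2023), and the revised base defaults (12 layers, hidden 384, 6 heads) match DiT-S. A self-contained summary, not code from this repo:

```python
DIT_SIZES = {
    "DiT-S":  {"num_layers": 12, "hidden_size": 384,  "num_attention_heads": 6},
    "DiT-B":  {"num_layers": 12, "hidden_size": 768,  "num_attention_heads": 12},
    "DiT-L":  {"num_layers": 24, "hidden_size": 1024, "num_attention_heads": 16},
    "DiT-XL": {"num_layers": 28, "hidden_size": 1152, "num_attention_heads": 16},
}

for name, cfg in DIT_SIZES.items():
    head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
    print(f"{name}: {cfg}, head_dim={head_dim}")  # 64 for S/B/L, 72 for XL
```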
3 changes: 1 addition & 2 deletions dfm/src/megatron/model/dit/dit_step.py
@@ -50,7 +50,7 @@ def on_validation_start(self, state, batch, model):
num_steps=model.config.val_generation_num_steps,
is_negative_prompt=True if "neg_context_embeddings" in batch else False,
)
caption = batch["video_metadata"][0]["caption"]
caption = batch["video_metadata"][0]["caption"] if "caption" in batch["video_metadata"][0] else "no caption"
latent = latent[0, None, : batch["seq_len_q"][0]]
latent = rearrange(
latent,
@@ -157,7 +157,6 @@ def forward_step(self, state, batch, model, return_schedule_plan: bool = False):

check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss
check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss
# import pdb;pdb.set_trace()
straggler_timer = state.straggler_timer
with straggler_timer:
if parallel_state.is_pipeline_last_stage():