Merged
Commits
92 commits
75ed5b2
first commit
Oct 30, 2025
2ebfd50
workable code
Oct 30, 2025
7b834f0
workable thd
Oct 31, 2025
2152abd
clean up, remove all CP for sbhd, CP now is only for thd
Oct 31, 2025
389a037
run outside of Mbridge
Oct 31, 2025
daac350
Update example scripts and add new data module for multimodal datasets
abhinavg4 Nov 3, 2025
d5d0106
workable code before refactoring
Nov 3, 2025
c4f5160
Merge remote-tracking branch 'origin/huvu/mcore_wan' into huvu/mcore_wan
Nov 3, 2025
0430384
refactor attention submodules + reorder files locations
Nov 4, 2025
dfff86b
update refactor
Nov 4, 2025
abbaa2a
update refactor
Nov 4, 2025
c59f6a2
reorganize files
Nov 4, 2025
0b91a1c
reorganize files
Nov 4, 2025
aa20504
refactoring code
Nov 5, 2025
d5f58c9
add README for perf test
Nov 5, 2025
9b8e4fb
using vae, t5, scheduler from Diffusers
Nov 5, 2025
7f414ae
update repo, remove Wan's Github modules
Nov 5, 2025
62a518f
Merge remote-tracking branch 'origin/main' into huvu/mcore_wan
Nov 6, 2025
2de5934
fix Ruff
Nov 6, 2025
6b46a7f
fix ruff + copyright
Nov 6, 2025
c1d8923
fix Ruff + Lint
Nov 6, 2025
e8de1ae
fix Ruff + Lint
Nov 6, 2025
287ad34
fix Ruff + Lint
Nov 6, 2025
4464fd2
fix Ruff + Lint
Nov 6, 2025
547339a
fix Ruff + Lint
Nov 6, 2025
9cd082b
fix Ruff + Lint
Nov 6, 2025
4514eee
fix Ruff + Lint
Nov 6, 2025
acd430d
fix Ruff + Lint
Nov 6, 2025
19c0c29
Merge remote-tracking branch 'origin/main' into huvu/mcore_wan
Nov 6, 2025
a147258
merged main + address comments
Nov 6, 2025
f3828b0
remove example_commands.md, Google waits until mid Nov
Nov 6, 2025
4727447
refactor inference_configs + mockdatamodule
Nov 6, 2025
8f49e23
add dit_embeddings.py
Nov 6, 2025
4766b1b
fix lint ruff
Nov 6, 2025
c4004ea
add 'average_gradients_across_tp_domain' to torch.nn for when running…
Nov 7, 2025
c14001d
merge from main
Nov 7, 2025
e332cb2
add english negative prompt
Nov 7, 2025
bc03727
fix ruff lint
Nov 7, 2025
d7c1acb
Update uv.lock for deps: diffusers==0.35.1, easydict, imageio
Nov 7, 2025
c525013
update dfm/src/megatron/data/dit
Nov 7, 2025
0f57585
change english negative prompt
Nov 8, 2025
d17286d
seem to workable seq_packing
Nov 10, 2025
e936907
refactor with Sajad's PR - DiT data to common dir
Nov 10, 2025
66796b5
fix Ruff, lint
Nov 10, 2025
7d8e64f
fix Ruff, lint
Nov 10, 2025
6263299
fix Ruff, lint
Nov 10, 2025
377ff5b
workable mock datamodule (doesn't need setting path); updated trainin…
Nov 11, 2025
0ca76a8
merge main
Nov 11, 2025
d8550c4
bring wan_task encoders features to common, sharing with dit
Nov 11, 2025
a13d0c0
lint, ruff
Nov 11, 2025
39b0e73
lint, ruff
Nov 11, 2025
4647d89
lint, ruff
Nov 11, 2025
174bb7b
fix CP error (input of thd_split_inputs_cp to be cu_seqlens_q_padded …
Nov 12, 2025
462638a
update README_perf_test.md
Nov 12, 2025
f5c10a1
fix lint, ruff
Nov 12, 2025
0b0058f
update uv.lock, merge main
Nov 12, 2025
13968fc
update uv.lock, merge main
Nov 12, 2025
46aa6d8
uv.lock
Nov 12, 2025
6b553ec
uv.lock
Nov 12, 2025
b1c41fc
uv.lock
Nov 12, 2025
681145b
update uv.lock [using ci]
pablo-garay Nov 12, 2025
7ad788e
Performance improvements to Wan
parthmannan Nov 13, 2025
0fd0e27
Perf optimizations
parthmannan Nov 14, 2025
0886941
Merge remote-tracking branch 'origin/huvu/mcore_wan' into pmannan/dfm…
parthmannan Nov 14, 2025
6489c79
Merge branch 'main' of github.com:NVIDIA-NeMo/DFM into pmannan/dfm_perf
parthmannan Nov 14, 2025
fd373f9
Tiny fix
parthmannan Nov 14, 2025
6a93703
Remove CP disable as packed sequences not supported
parthmannan Nov 15, 2025
345b53e
Fix comment
parthmannan Nov 15, 2025
c9e55bb
Minor fixes. Revert video_latent comparison
parthmannan Nov 17, 2025
3504d3f
Fix missed check
parthmannan Nov 17, 2025
b2fef2f
Lint fix
parthmannan Nov 17, 2025
b5ac649
H100 mock pretraining perf config
parthmannan Nov 18, 2025
52083ac
Rename config file
parthmannan Nov 18, 2025
fef3196
Lint check
parthmannan Nov 18, 2025
601a63a
Adding GB200 perf config
parthmannan Nov 18, 2025
bc191b2
GB300 perf config
parthmannan Nov 19, 2025
992c8fb
Refactor Energon data module to return wrapped dataloaders and add En…
abhinavg4 Nov 19, 2025
55e6133
Merge branch 'main' into pmannan/dfm_perf
abhinavg4 Nov 19, 2025
3954ddd
Enhance DiffusionTaskEncoder to handle None attributes in stacking an…
abhinavg4 Nov 19, 2025
45aff1b
Refactor data processing in dit_data_step to simplify batch retrieval…
abhinavg4 Nov 19, 2025
5c02148
Merge branch 'main' of github.com:NVIDIA-NeMo/DFM into pmannan/dfm_perf
parthmannan Nov 20, 2025
6aff745
Add op fusions
parthmannan Nov 20, 2025
7a2ec65
Update H100 config
parthmannan Nov 20, 2025
b90fc5a
Fix lint
parthmannan Nov 20, 2025
92abc8d
Merge branch 'main' of github.com:NVIDIA-NeMo/DFM into pmannan/dfm_perf
parthmannan Nov 20, 2025
78ef7b6
Resolve conflict
parthmannan Nov 20, 2025
1a2d662
Fix for mock dataloader test
parthmannan Nov 20, 2025
1b055c6
Fix Dummyiter
parthmannan Nov 20, 2025
04f6c14
Fix test
parthmannan Nov 20, 2025
7247bc7
Make RoPE test only GPU
parthmannan Nov 21, 2025
b17a40d
Rope cuda fix
parthmannan Nov 21, 2025
7fedda7
Merge branch 'main' of github.com:NVIDIA-NeMo/DFM into pmannan/dfm_perf
parthmannan Nov 21, 2025
27 changes: 25 additions & 2 deletions dfm/src/megatron/data/common/base_energon_datamodule.py
@@ -195,7 +195,7 @@ def train_dataloader(self) -> Any:
train_dataset = self.datasets_provider(worker_config, split="train")
energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config)
self.train_dataloader_object = energon_dataloader
return self.train_dataloader_object
return EnergonDataloader(self.train_dataloader_object)

def val_dataloader(self):
"""
@@ -233,7 +233,7 @@ def val_dataloader(self):
val_dataset = self.datasets_provider(worker_config, split="val")
energon_loader = get_savable_loader(val_dataset, worker_config=worker_config)
self.val_dataloader_object = energon_loader
return self.val_dataloader_object
return EnergonDataloader(self.val_dataloader_object)

def test_dataloader(self) -> None:
"""
@@ -337,3 +337,26 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
consumed_samples=consumed_samples,
consistency_check=False,
)


class EnergonDataloader:
"""A wrapper to use Megatron Energon dataloader with the Megatron-LM training loop."""

def __init__(self, dataloader):
self._dataloader = dataloader
self._iter = iter(cyclic_iter(dataloader))

def __next__(self):
return self._iter.__next__()

def __iter__(self):
return self._iter.__iter__()

def save_state(self):
return self._dataloader.save_state_rank()


def cyclic_iter(iter):
while True:
for x in iter:
yield x
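
For context, a minimal usage sketch of the wrapper added above. The Energon loader returned by get_savable_loader is replaced here by a tiny stand-in with the same save_state_rank() method (an assumption made only so the snippet runs on its own); the point is that wrapping turns a finite, savable loader into an endless iterator the Megatron-LM training loop can call next() on directly, while save_state() still delegates to the underlying loader for checkpointing.

```python
from typing import Iterator, List


class _ToySavableLoader:
    """Stand-in for the loader from get_savable_loader(); interface assumed."""

    def __init__(self, samples: List[int]) -> None:
        self._samples = samples

    def __iter__(self) -> Iterator[int]:
        return iter(self._samples)

    def save_state_rank(self) -> dict:
        # The real Energon loader returns a resumable per-rank state here.
        return {"consumed": len(self._samples)}


def cyclic_iter(it):
    # Same idea as in the diff: restart the underlying loader when exhausted.
    while True:
        for x in it:
            yield x


class EnergonDataloader:
    def __init__(self, dataloader) -> None:
        self._dataloader = dataloader
        self._iter = iter(cyclic_iter(dataloader))

    def __next__(self):
        return next(self._iter)

    def __iter__(self):
        return self._iter

    def save_state(self):
        return self._dataloader.save_state_rank()


loader = EnergonDataloader(_ToySavableLoader([1, 2, 3]))
print([next(loader) for _ in range(5)])  # [1, 2, 3, 1, 2] -- cycles past the end
print(loader.save_state())               # delegates to save_state_rank()
```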
10 changes: 8 additions & 2 deletions dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py
@@ -88,10 +88,16 @@ def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSamp
"""Construct a new Diffusion sample by concatenating the sequences."""

def stack(attr):
return torch.stack([getattr(sample, attr) for sample in samples], dim=0)
if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None:
return torch.stack([getattr(sample, attr) for sample in samples], dim=0)
else:
return None

def cat(attr):
return torch.cat([getattr(sample, attr) for sample in samples], dim=0)
if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None:
return torch.cat([getattr(sample, attr) for sample in samples], dim=0)
else:
return None

return DiffusionSample(
__key__=",".join([s.__key__ for s in samples]),
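
The guarded stack/cat helpers above let packing proceed when a sample type carries optional fields. A minimal sketch of the same pattern, using a hypothetical Sample dataclass rather than the real DiffusionSample:

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Sample:  # hypothetical stand-in for DiffusionSample
    video: torch.Tensor
    context: Optional[torch.Tensor] = None  # optional field, may be absent


def stack(samples: List[Sample], attr: str) -> Optional[torch.Tensor]:
    # Return None instead of raising when the first sample lacks the attribute.
    if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None:
        return torch.stack([getattr(s, attr) for s in samples], dim=0)
    return None


samples = [Sample(video=torch.randn(4, 8)) for _ in range(3)]
print(stack(samples, "video").shape)  # torch.Size([3, 4, 8])
print(stack(samples, "context"))      # None -- field absent, packing continues
```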
7 changes: 6 additions & 1 deletion dfm/src/megatron/data/wan/wan_mock_datamodule.py
@@ -113,12 +113,15 @@ class WanMockDataModuleConfig(DatasetProvider):
W_latents: int = 60
patch_spatial: int = 2
patch_temporal: int = 1
number_packed_samples: int = 3
number_packed_samples: int = 1
context_seq_len: int = 512
context_embeddings_dim: int = 4096

def __post_init__(self):
mock_ds = _MockDataset(length=1024)
kwargs = {}
if self.num_workers > 0:
kwargs["prefetch_factor"] = 8
self._train_dl = DataLoader(
mock_ds,
batch_size=self.micro_batch_size,
@@ -135,6 +138,8 @@ def __post_init__(self):
),
shuffle=False,
drop_last=False,
pin_memory=True,
**kwargs,
)
self._train_dl = iter(self._train_dl)
self.sequence_length = self.seq_length
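
The kwargs indirection above exists because torch.utils.data.DataLoader rejects prefetch_factor when no worker processes are used. A standalone sketch of the pattern with a toy dataset (the dataset and batch size are placeholders, not the real mock module):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):  # placeholder for the mock latent dataset
    def __len__(self) -> int:
        return 16

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.zeros(4)


def build_loader(num_workers: int) -> DataLoader:
    kwargs = {}
    if num_workers > 0:
        # prefetch_factor is only valid with worker processes; passing it
        # alongside num_workers=0 raises a ValueError.
        kwargs["prefetch_factor"] = 8
    return DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=num_workers,
        pin_memory=True,  # page-locked host memory speeds up async H2D copies
        **kwargs,
    )


loader = build_loader(num_workers=0)  # works because prefetch_factor is omitted
```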
4 changes: 3 additions & 1 deletion dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py
@@ -195,6 +195,9 @@ def training_step(
packed_seq_params["self_attention"].cu_seqlens_q_padded,
parallel_state.get_context_parallel_group(),
)
# TODO (pmannan): Disable CP for CrossAttention as KV context is small.
# We don't need to split context embeddings across context parallelism
# if we disable context parallelism for cross-attention
context_embeddings = thd_split_inputs_cp(
context_embeddings,
packed_seq_params["cross_attention"].cu_seqlens_kv_padded,
@@ -261,5 +264,4 @@ def training_step(
context=context_embeddings,
packed_seq_params=packed_seq_params,
)

return hidden_states
6 changes: 2 additions & 4 deletions dfm/src/megatron/model/wan/rope_utils.py
@@ -32,6 +32,8 @@ def __init__(self, dim_head, max_position_len):
],
dim=1,
)
if torch.cuda.is_available():
self.freqs = self.freqs.cuda()

def rope_params(self, max_position_len, dim_head, theta=10000):
assert dim_head % 2 == 0
@@ -41,10 +43,6 @@ def rope_params(self, max_position_len, dim_head, theta=10000):
return freqs

def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device):
self.freqs = self.freqs.to(
device,
)

n, c = n_head, dim_head // 2

# split freqs
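
The change above moves the frequency table onto the GPU once at construction instead of calling .to(device) on every forward pass. A generic sketch of that pattern, with a deliberately simplified frequency table rather than the real Wan RoPE math:

```python
import torch


class RopeFreqs(torch.nn.Module):
    """Toy illustration: build the table once, place it on the GPU once."""

    def __init__(self, dim_head: int = 64, max_position_len: int = 1024) -> None:
        super().__init__()
        half = dim_head // 2
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, 2).double() / half))
        self.freqs = torch.outer(torch.arange(max_position_len), inv_freq)
        if torch.cuda.is_available():
            # one-time placement; forward() no longer pays for a host-to-device copy
            self.freqs = self.freqs.cuda()

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return self.freqs[positions]


rope = RopeFreqs()
out = rope(torch.arange(8, device=rope.freqs.device))
print(out.shape, out.device)
```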
68 changes: 46 additions & 22 deletions dfm/src/megatron/model/wan/wan_layer_spec.py
@@ -14,6 +14,7 @@

# pylint: disable=C0115,C0116,C0301

import copy
from dataclasses import dataclass
from typing import Optional, Union

@@ -65,10 +66,16 @@ def __init__(self, config: TransformerConfig):

setattr(self.modulation, "sequence_parallel", config.sequence_parallel)

@jit_fuser
def forward(self, timestep_emb):
e = (self.modulation + timestep_emb).chunk(6, dim=1)
e = (self.modulation + timestep_emb).transpose(0, 1)
e = e.chunk(6, dim=0)
return e

@jit_fuser
def normalize_modulate(self, norm, hidden_states, shift, scale):
return self.modulate(norm(hidden_states), shift, scale)

@jit_fuser
def modulate(self, x, shift, scale):
return x * (1 + scale) + shift
@@ -96,19 +103,31 @@ def __init__(
pg_collection: Optional[ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
):
def _replace_no_cp_submodules(submodules):
modified_submods = copy.deepcopy(submodules)
modified_submods.cross_attention = IdentityOp
return modified_submods

# Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init.
# modified_submods = _replace_no_cp_submodules(submodules)
super().__init__(
config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout
)

# # TODO: Override Cross Attention to disable TP Comm overlap as well. ???
# # Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes.
# cp_override_config = copy.deepcopy(config)
# cp_override_config.tp_comm_overlap = False
# self.cross_attention = build_module(
# submodules.cross_attention,
# config=cp_override_config,
# layer_number=layer_number,
# )
# TODO (pmannan): Override Cross Attention to disable CP.
# Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as
# Q and lead to incorrect tensor shapes.
# if submodules.cross_attention != IdentityOp:
# cp_override_config = copy.deepcopy(config)
# cp_override_config.context_parallel_size = 1
# cp_override_config.tp_comm_overlap = False
# self.cross_attention = build_module(
# submodules.cross_attention,
# config=cp_override_config,
# layer_number=layer_number,
# )
# else:
# self.cross_attention = None

self.full_self_attention = build_module(
submodules.full_self_attention,
@@ -148,6 +167,10 @@ def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None)
if isinstance(param, nn.Parameter) and param.requires_grad:
setattr(param, "average_gradients_across_tp_domain", True)

@jit_fuser
def add_residual(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
return x + residual

def forward(
self,
hidden_states,
@@ -169,19 +192,13 @@ def forward(
rope_emb = rotary_pos_emb

shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb)
# transpose to bring it to [1, b, ...] format
shift_full = shift_full.transpose(0, 1)
scale_full = scale_full.transpose(0, 1)
gate_full = gate_full.transpose(0, 1)
shift_mlp = shift_mlp.transpose(0, 1)
scale_mlp = scale_mlp.transpose(0, 1)
gate_mlp = gate_mlp.transpose(0, 1)

# ******************************************** full self attention *******************************************

# adaLN with scale + shift + gate
pre_full_attn_layernorm_output_ada = self.adaLN.modulate(
self.norm1(hidden_states),
pre_full_attn_layernorm_output_ada = self.adaLN.normalize_modulate(
self.norm1,
hidden_states,
shift=shift_full,
scale=scale_full,
)
@@ -201,6 +218,12 @@

# ******************************************** cross attention ******************************************************

# TODO (pmannan): Disable CP for CrossAttention as KV context is small.
# But needs better support for packed sequences and padding to ensure correct calculations
# packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor(
# [0, hidden_states.shape[0]],
# device=packed_seq_params['cross_attention'].cu_seqlens_kv.device,
# dtype=torch.int32)
attention_output, bias = self.cross_attention(
self.norm3(hidden_states),
attention_mask=context_mask,
@@ -210,12 +233,13 @@
if bias is not None:
attention_output = attention_output + bias

hidden_states = hidden_states + attention_output
hidden_states = self.add_residual(hidden_states, attention_output)

# ******************************************** mlp ******************************************************

pre_mlp_layernorm_output_ada = self.adaLN.modulate(
self.norm2(hidden_states),
pre_mlp_layernorm_output_ada = self.adaLN.normalize_modulate(
self.norm2,
hidden_states,
shift=shift_mlp,
scale=scale_mlp,
)
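
The adaLN change above replaces six per-chunk transposes in the layer's forward with a single transpose inside the fused forward (and normalize_modulate folds the layernorm call into the same jit-fusable helper). A small check of the layout equivalence, assuming the modulation tensor has shape [batch, 6, hidden]:

```python
import torch

b, hidden = 2, 16
mod = torch.randn(b, 6, hidden)  # stand-in for self.modulation + timestep_emb

# Old path: chunk along dim=1 into six [b, 1, hidden] tensors, then transpose
# each chunk to the [1, b, hidden] layout the transformer layer expects.
old = [c.transpose(0, 1) for c in mod.chunk(6, dim=1)]

# New path: transpose once to [6, b, hidden] and chunk along dim=0; each
# chunk already has the [1, b, hidden] layout, so the six transposes go away.
new = mod.transpose(0, 1).chunk(6, dim=0)

assert all(torch.equal(o, n) for o, n in zip(old, new))
print("layouts match:", new[0].shape)
```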
2 changes: 2 additions & 0 deletions dfm/src/megatron/model/wan/wan_provider.py
@@ -51,6 +51,8 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
bf16: bool = False
params_dtype: torch.dtype = torch.float32
qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd"
apply_rope_fusion: bool = True
bias_activation_fusion: bool = True
# these attributes are unused for images/videos, we just set because bridge training requires for LLMs
seq_length: int = 1024
share_embeddings_and_output_weights: bool = False
4 changes: 1 addition & 3 deletions dfm/src/megatron/model/wan/wan_step.py
@@ -32,10 +32,8 @@


def wan_data_step(qkv_format, dataloader_iter):
batch = next(iter(dataloader_iter.iterable))

batch = next(dataloader_iter)
batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}

# Construct packed sequence parameters
if ("seq_len_q" in batch) and ("seq_len_kv" in batch):
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
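
The packed-sequence construction in this hunk is cut off above, but the step it begins is the standard one for THD attention: turn per-sample lengths into cumulative offsets with a leading zero. A hedged sketch of that step on the CPU (the real code builds these tensors on the GPU for the packed-sequence parameters; the lengths here are made up for illustration):

```python
import torch

# hypothetical per-sample sequence lengths for one packed (thd) micro-batch
seq_len_q = torch.tensor([128, 96, 64], dtype=torch.int32)

zero = torch.zeros(1, dtype=torch.int32)
cu_seqlens_q = torch.cat([zero, torch.cumsum(seq_len_q, dim=0).to(torch.int32)])
print(cu_seqlens_q)  # tensor([  0, 128, 224, 288], dtype=torch.int32)

# Packed attention kernels read sample i as the token slice
# [cu_seqlens_q[i], cu_seqlens_q[i + 1]) of the concatenated batch.
```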
2 changes: 1 addition & 1 deletion dfm/src/megatron/recipes/wan/wan.py
@@ -170,7 +170,7 @@ def pretrain_config(
context_embeddings_dim=4096,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
num_workers=10,
num_workers=16,
packing_buffer_size=None,
)
else:
44 changes: 44 additions & 0 deletions examples/megatron/override_configs/wan_pretrain_sample_data.yaml
@@ -0,0 +1,44 @@
# WAN Pretrain Mock Data Test Configuration
# Converted from L2_Function_Tests_GPU_Wan_Mock_Data.sh

model:
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  context_parallel_size: 1
  crossattn_emb_size: 1536
  hidden_size: 1536
  ffn_hidden_size: 8960
  num_attention_heads: 12
  num_layers: 3
  qkv_format: thd
  seq_length: 2048

train:
  eval_iters: 0
  train_iters: 10
  global_batch_size: 2
  micro_batch_size: 1

optimizer:
  lr: 5.0e-6
  min_lr: 5.0e-6

scheduler:
  lr_decay_style: constant
  lr_warmup_iters: 0

checkpoint:
  save: ${oc.env:CHECKPOINT_DIR,null}
  load: ${oc.env:CHECKPOINT_DIR,null}
  load_optim: false
  save_interval: 200

dataset:
  path: ${oc.env:DATASET_PATH,null}
  seq_length: 2048
  global_batch_size: 2
  micro_batch_size: 1
  packing_buffer_size: 50

logger:
  log_interval: 1
33 changes: 33 additions & 0 deletions examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml
@@ -0,0 +1,33 @@
model:
  tensor_model_parallel_size: 1
  sequence_parallel: false
  pipeline_model_parallel_size: 1
  context_parallel_size: 4
  crossattn_emb_size: 5120
  hidden_size: 5120
  ffn_hidden_size: 13824
  num_attention_heads: 40
  num_layers: 40
  qkv_format: thd
  seq_length: 2048 # This is not used

train:
  global_batch_size: 64
  micro_batch_size: 1
  eval_iters: 0

scheduler:
  lr_decay_style: constant
  lr_warmup_iters: 0

optimizer:
  lr: 5e-6
  min_lr: 5e-6

dataset:
  seq_length: 2048 # This is not used
  global_batch_size: 64
  micro_batch_size: 1

logger:
  log_interval: 1
33 changes: 33 additions & 0 deletions examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml
@@ -0,0 +1,33 @@
model:
  tensor_model_parallel_size: 1
  sequence_parallel: false
  pipeline_model_parallel_size: 1
  context_parallel_size: 2
  crossattn_emb_size: 5120
  hidden_size: 5120
  ffn_hidden_size: 13824
  num_attention_heads: 40
  num_layers: 40
  qkv_format: thd
  seq_length: 2048 # This is not used

train:
  global_batch_size: 64
  micro_batch_size: 1
  eval_iters: 0

scheduler:
  lr_decay_style: constant
  lr_warmup_iters: 0

optimizer:
  lr: 5e-6
  min_lr: 5e-6

dataset:
  seq_length: 2048 # This is not used
  global_batch_size: 64
  micro_batch_size: 1

logger:
  log_interval: 1