Skip to content
Merged
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
75ed5b2
first commit
Oct 30, 2025
2ebfd50
workable code
Oct 30, 2025
7b834f0
workable thd
Oct 31, 2025
2152abd
clean up, remove all CP for sbhd, CP now is only for thd
Oct 31, 2025
389a037
run outside of Mbridge
Oct 31, 2025
daac350
Update example scripts and add new data module for multimodal datasets
abhinavg4 Nov 3, 2025
d5d0106
workable code before refactoring
Nov 3, 2025
c4f5160
Merge remote-tracking branch 'origin/huvu/mcore_wan' into huvu/mcore_wan
Nov 3, 2025
0430384
refactor attention submodules + reorder files locations
Nov 4, 2025
dfff86b
update refactor
Nov 4, 2025
abbaa2a
update refactor
Nov 4, 2025
c59f6a2
reorganize files
Nov 4, 2025
0b91a1c
reorganize files
Nov 4, 2025
aa20504
refactoring code
Nov 5, 2025
d5f58c9
add README for perf test
Nov 5, 2025
9b8e4fb
using vae, t5, scheduler from Diffusers
Nov 5, 2025
7f414ae
update repo, remove Wan's Github modules
Nov 5, 2025
62a518f
Merge remote-tracking branch 'origin/main' into huvu/mcore_wan
Nov 6, 2025
2de5934
fix Ruff
Nov 6, 2025
6b46a7f
fix ruff + copyright
Nov 6, 2025
c1d8923
fix Ruff + Lint
Nov 6, 2025
e8de1ae
fix Ruff + Lint
Nov 6, 2025
287ad34
fix Ruff + Lint
Nov 6, 2025
4464fd2
fix Ruff + Lint
Nov 6, 2025
547339a
fix Ruff + Lint
Nov 6, 2025
9cd082b
fix Ruff + Lint
Nov 6, 2025
4514eee
fix Ruff + Lint
Nov 6, 2025
acd430d
fix Ruff + Lint
Nov 6, 2025
19c0c29
Merge remote-tracking branch 'origin/main' into huvu/mcore_wan
Nov 6, 2025
a147258
merged main + address comments
Nov 6, 2025
f3828b0
remove example_commands.md, Google waits until mid Nov
Nov 6, 2025
4727447
refactor inference_configs + mockdatamodule
Nov 6, 2025
8f49e23
add dit_embeddings.py
Nov 6, 2025
4766b1b
fix lint ruff
Nov 6, 2025
c4004ea
add 'average_gradients_across_tp_domain' to torch.nn for when running…
Nov 7, 2025
c14001d
merge from main
Nov 7, 2025
e332cb2
add english negative prompt
Nov 7, 2025
bc03727
fix ruff lint
Nov 7, 2025
d7c1acb
Update uv.lock for deps: diffusers==0.35.1, easydict, imageio
Nov 7, 2025
c525013
update dfm/src/megatron/data/dit
Nov 7, 2025
0f57585
change english negative prompt
Nov 8, 2025
d17286d
seem to workable seq_packing
Nov 10, 2025
e936907
refactor with Sajad's PR - DiT data to common dir
Nov 10, 2025
66796b5
fix Ruff, lint
Nov 10, 2025
7d8e64f
fix Ruff, lint
Nov 10, 2025
6263299
fix Ruff, lint
Nov 10, 2025
377ff5b
workable mock datamodule (doesn't need setting path); updated trainin…
Nov 11, 2025
0ca76a8
merge main
Nov 11, 2025
d8550c4
bring wan_task encoders features to common, sharing with dit
Nov 11, 2025
a13d0c0
lint, ruff
Nov 11, 2025
39b0e73
lint, ruff
Nov 11, 2025
4647d89
lint, ruff
Nov 11, 2025
174bb7b
fix CP error (input of thd_split_inputs_cp to be cu_seqlens_q_padded …
Nov 12, 2025
462638a
update README_perf_test.md
Nov 12, 2025
f5c10a1
fix lint, ruff
Nov 12, 2025
0b0058f
update uv.lock, merge main
Nov 12, 2025
13968fc
update uv.lock, merge main
Nov 12, 2025
46aa6d8
uv.lock
Nov 12, 2025
6b553ec
uv.lock
Nov 12, 2025
b1c41fc
uv.lock
Nov 12, 2025
681145b
update uv.lock [using ci]
pablo-garay Nov 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dfm/src/megatron/data/common/diffusion_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ class DiffusionSample(Sample):
num_frames (Optional[torch.Tensor]): Number of frames in the video.
padding_mask (Optional[torch.Tensor]): Mask indicating padding positions.
seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings.
seq_len_q_padded (Optional[torch.Tensor]): Sequence length for query embeddings after padding.
seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings.
pos_ids (Optional[torch.Tensor]): Positional IDs.
latent_shape (Optional[torch.Tensor]): Shape of the latent tensor.
video_metadata (Optional[dict]): Metadata of the video.
"""

video: torch.Tensor # video latents (C T H W)
Expand All @@ -48,9 +50,12 @@ class DiffusionSample(Sample):
num_frames: Optional[torch.Tensor] = None
padding_mask: Optional[torch.Tensor] = None
seq_len_q: Optional[torch.Tensor] = None
seq_len_q_padded: Optional[torch.Tensor] = None
seq_len_kv: Optional[torch.Tensor] = None
seq_len_kv_padded: Optional[torch.Tensor] = None
pos_ids: Optional[torch.Tensor] = None
latent_shape: Optional[torch.Tensor] = None
video_metadata: Optional[dict] = None

def to_dict(self) -> dict:
"""Converts the sample to a dictionary."""
Expand All @@ -64,9 +69,12 @@ def to_dict(self) -> dict:
num_frames=self.num_frames,
padding_mask=self.padding_mask,
seq_len_q=self.seq_len_q,
seq_len_q_padded=self.seq_len_q_padded,
seq_len_kv=self.seq_len_kv,
seq_len_kv_padded=self.seq_len_kv_padded,
pos_ids=self.pos_ids,
latent_shape=self.latent_shape,
video_metadata=self.video_metadata,
)

def __add__(self, other: Any) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import random
from abc import ABC, abstractmethod
from typing import List
Expand Down Expand Up @@ -103,9 +102,12 @@ def cat(attr):
context_embeddings=cat("context_embeddings"),
loss_mask=cat("loss_mask"),
seq_len_q=cat("seq_len_q"),
seq_len_q_padded=cat("seq_len_q_padded"),
seq_len_kv=cat("seq_len_kv"),
seq_len_kv_padded=cat("seq_len_kv_padded"),
pos_ids=cat("pos_ids"),
latent_shape=stack("latent_shape"),
video_metadata=[sample.video_metadata for sample in samples],
)

@stateless
Expand Down
49 changes: 49 additions & 0 deletions dfm/src/megatron/data/wan/wan_energon_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=C0115,C0116,C0301

from dataclasses import dataclass

from megatron.bridge.data.utils import DatasetBuildContext
from torch import int_repr

from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule, DiffusionDataModuleConfig
from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder


@dataclass(kw_only=True)
class WanDataModuleConfig(DiffusionDataModuleConfig):
    """Configuration for the Wan Energon data module.

    On construction, wires a ``DiffusionDataModule`` with a ``WanTaskEncoder``
    that shares the same sequence-packing parameters, and exposes its training
    dataloader for all three dataset splits.
    """

    path: str  # root path of the Energon dataset
    seq_length: int  # target packed sequence length
    packing_buffer_size: int  # number of samples buffered for sequence packing
    micro_batch_size: int
    global_batch_size: int
    # Fixed: was annotated `int_repr` (a torch function mistakenly imported and
    # used as a type); the number of dataloader workers is a plain int.
    num_workers: int
    dataloader_type: str = "external"

    def __post_init__(self):
        # Build the underlying data module once; the task encoder shares
        # seq_length / packing_buffer_size so packed sequences fit the budget.
        self.dataset = DiffusionDataModule(
            path=self.path,
            seq_length=self.seq_length,
            packing_buffer_size=self.packing_buffer_size,
            task_encoder=WanTaskEncoder(seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size),
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            num_workers=self.num_workers,
        )
        self.sequence_length = self.dataset.seq_length

    def build_datasets(self, context: DatasetBuildContext):
        """Return (train, valid, test) dataloaders.

        NOTE(review): all three splits are served by the training dataloader —
        confirm that validation/test on training data is intended for this setup.
        """
        return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
144 changes: 144 additions & 0 deletions dfm/src/megatron/data/wan/wan_mock_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=C0115,C0116,C0301

from dataclasses import dataclass

import torch
from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
from torch.utils.data import DataLoader, Dataset

from dfm.src.megatron.model.wan.utils import patchify


class _MockDataset(Dataset):
    """Minimal map-style dataset yielding empty dicts.

    Items carry no payload — the collate function fabricates the whole batch —
    so only the dataset length matters to the DataLoader.
    """

    def __init__(self, length: int):
        # Clamp to at least one element so the dataloader is never empty.
        n = int(length)
        self.length = n if n > 0 else 1

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> dict:
        # Payload is irrelevant; collate_fn builds the actual batch.
        return {}


def mock_batch(
    F_latents: int,
    H_latents: int,
    W_latents: int,
    patch_temporal: int,
    patch_spatial: int,
    number_packed_samples: int,
    context_seq_len: int,
    context_embeddings_dim: int,
) -> dict:
    """Fabricate one packed batch of synthetic Wan diffusion training data.

    A single random 16-channel video latent is patchified, then the same sample
    is replicated ``number_packed_samples`` times to emulate sequence packing.

    Returned shapes:
        video_latents:        [num_patches_total, 1, 16 * pF * pH * pW]
        context_embeddings:   [context_seq_len * n, 1, context_embeddings_dim]
        loss_mask:            [num_patches_total, 1]
        seq_len_q/_padded:    [n] int32 (identical — mock data has no padding)
        seq_len_kv/_padded:   [n] int32
        grid_sizes:           [n, 3] int32, patch counts in (T, H, W) order
        video_metadata:       empty dict placeholder
    """
    n = number_packed_samples

    # --- build a single synthetic sample ---
    latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32)
    # Patch counts along temporal/height/width of the (C, T, H, W) latent.
    grid = torch.tensor(
        [
            latent.shape[1] // patch_temporal,
            latent.shape[2] // patch_spatial,
            latent.shape[3] // patch_spatial,
        ],
        dtype=torch.int32,
    )
    latent = torch.as_tensor(
        patchify([latent], (patch_temporal, patch_spatial, patch_spatial))[0],
        dtype=torch.float32,
    )
    q_len = latent.shape[0]
    q_len_padded = q_len  # no padding in the mock path
    mask = torch.ones(q_len, dtype=torch.bfloat16)
    context = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.float32)
    kv_len = context.shape[0]
    kv_len_padded = kv_len
    metadata = {}

    # --- replicate the sample n times to emulate a packed sequence ---
    return dict(
        video_latents=torch.cat([latent] * n, dim=0).unsqueeze(1),
        context_embeddings=torch.cat([context] * n, dim=0).unsqueeze(1),
        loss_mask=torch.cat([mask] * n, dim=0).unsqueeze(1),
        seq_len_q=torch.tensor([q_len] * n, dtype=torch.int32),
        seq_len_q_padded=torch.tensor([q_len_padded] * n, dtype=torch.int32),
        seq_len_kv=torch.tensor([kv_len] * n, dtype=torch.int32),
        seq_len_kv_padded=torch.tensor([kv_len_padded] * n, dtype=torch.int32),
        grid_sizes=torch.stack([grid] * n, dim=0),
        video_metadata=metadata,
    )


@dataclass(kw_only=True)
class WanMockDataModuleConfig(DatasetProvider):
    """Dataset provider serving synthetic Wan batches — no real data required.

    Every batch is fabricated by ``mock_batch`` at collate time; the underlying
    dataset only supplies a length. The same dataloader backs all three splits.
    """

    path: str = ""  # unused; kept for interface parity with the real datamodule
    seq_length: int
    packing_buffer_size: int
    micro_batch_size: int
    global_batch_size: int
    num_workers: int
    dataloader_type: str = "external"
    # Geometry of the fabricated sample: latent frames/height/width, patch
    # sizes, packing multiplicity, and text-context dimensions.
    F_latents: int = 24
    H_latents: int = 104
    W_latents: int = 60
    patch_spatial: int = 2
    patch_temporal: int = 1
    number_packed_samples: int = 3
    context_seq_len: int = 512
    context_embeddings_dim: int = 4096

    def __post_init__(self):
        def _collate(_samples):
            # The (empty) samples are ignored — the whole batch is synthesized.
            return mock_batch(
                F_latents=self.F_latents,
                H_latents=self.H_latents,
                W_latents=self.W_latents,
                patch_temporal=self.patch_temporal,
                patch_spatial=self.patch_spatial,
                number_packed_samples=self.number_packed_samples,
                context_seq_len=self.context_seq_len,
                context_embeddings_dim=self.context_embeddings_dim,
            )

        self._train_dl = DataLoader(
            _MockDataset(length=1024),
            batch_size=self.micro_batch_size,
            num_workers=self.num_workers,
            collate_fn=_collate,
            shuffle=False,
            drop_last=False,
        )
        self.sequence_length = self.seq_length

    def build_datasets(self, _context: DatasetBuildContext):
        """Return (train, valid, test) dataloaders; all reuse the mock loader."""
        # If a real dataset was attached externally, defer to it.
        if hasattr(self, "dataset"):
            return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
        return self._train_dl, self._train_dl, self._train_dl
Loading
Loading