Supporting Wan model #21
Merged
61 commits:
- 75ed5b2 first commit
- 2ebfd50 workable code
- 7b834f0 workable thd
- 2152abd clean up, remove all CP for sbhd, CP now is only for thd
- 389a037 run outside of Mbridge
- daac350 Update example scripts and add new data module for multimodal datasets
- d5d0106 workable code before refactoring (abhinavg4)
- c4f5160 Merge remote-tracking branch 'origin/huvu/mcore_wan' into huvu/mcore_wan
- 0430384 refactor attention submodules + reorder files locations
- dfff86b update refactor
- abbaa2a update refactor
- c59f6a2 reorganize files
- 0b91a1c reorganize files
- aa20504 refactoring code
- d5f58c9 add README for perf test
- 9b8e4fb using vae, t5, scheduler from Diffusers
- 7f414ae update repo, remove Wan's Github moduels
- 62a518f Merge remote-tracking branch 'origin/main' into huvu/mcore_wan
- 2de5934 fix Ruff
- 6b46a7f fix ruff + copyright
- c1d8923 fix Ruff + Lint
- e8de1ae fix Ruff + Lint
- 287ad34 fix Ruff + Lint
- 4464fd2 fix Ruff + Lint
- 547339a fix Ruff + Lint
- 9cd082b fix Ruff + Lint
- 4514eee fix Ruff + Lint
- acd430d fix Ruff + Lint
- 19c0c29 Merge remote-tracking branch 'origin/main' into huvu/mcore_wan
- a147258 merged main + address comments
- f3828b0 remove example_commands.md, Google waits until mid Nov
- 4727447 refactor inference_configs + mockdatamodule
- 8f49e23 add dit_embeddings.py
- 4766b1b fix lint ruff
- c4004ea add 'average_gradients_across_tp_domain' to torch.nn for when running…
- c14001d merge from main
- e332cb2 add english negative prompt
- bc03727 fix ruff lint
- d7c1acb Update uv.lock for deps: diffusers==0.35.1, easydict, imageio
- c525013 update dfm/src/megatron/data/dit
- 0f57585 seem to workable seq_packing
- d17286d change english negative prompt
- e936907 refactor with Sajad's PR - DiT data to common dir
- 66796b5 fix Ruff, lint
- 7d8e64f fix Ruff, lint
- 6263299 fix Ruff, lint
- 377ff5b workable mock datamodule (doesn't need setting path); updated trainin…
- 0ca76a8 merge main
- d8550c4 bring wan_task encoders features to common, sharing with dit
- a13d0c0 lint, ruff
- 39b0e73 lint, ruff
- 4647d89 lint, ruff
- 174bb7b fix CP error (input of thd_split_inputs_cp to be cu_seqlens_q_padded …
- 462638a udpate README_perf_test.md
- f5c10a1 fix lint, ruff
- 0b0058f update uv.lock, merge main
- 13968fc update uv.lock, merge main
- 46aa6d8 uv.lock
- 6b553ec uv.lock
- b1c41fc uv.lock
- 681145b update uv.lock [using ci]
New file (49 lines) — Wan Energon datamodule config:

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=C0115,C0116,C0301

from dataclasses import dataclass

from megatron.bridge.data.utils import DatasetBuildContext

from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule, DiffusionDataModuleConfig
from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder


@dataclass(kw_only=True)
class WanDataModuleConfig(DiffusionDataModuleConfig):
    path: str
    seq_length: int
    packing_buffer_size: int
    micro_batch_size: int
    global_batch_size: int
    num_workers: int
    dataloader_type: str = "external"

    def __post_init__(self):
        self.dataset = DiffusionDataModule(
            path=self.path,
            seq_length=self.seq_length,
            packing_buffer_size=self.packing_buffer_size,
            task_encoder=WanTaskEncoder(seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size),
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            num_workers=self.num_workers,
        )
        self.sequence_length = self.dataset.seq_length

    def build_datasets(self, context: DatasetBuildContext):
        return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
```
New file (144 lines) — Wan mock datamodule:

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=C0115,C0116,C0301

from dataclasses import dataclass

import torch
from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
from torch.utils.data import DataLoader, Dataset

from dfm.src.megatron.model.wan.utils import patchify


class _MockDataset(Dataset):
    def __init__(self, length: int):
        self.length = max(int(length), 1)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> dict:
        return {}


def mock_batch(
    F_latents: int,
    H_latents: int,
    W_latents: int,
    patch_temporal: int,
    patch_spatial: int,
    number_packed_samples: int,
    context_seq_len: int,
    context_embeddings_dim: int,
) -> dict:
    # Set mock values for one video sample.
    video_latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32)
    grid_size = torch.tensor(
        [
            video_latent.shape[1] // patch_temporal,
            video_latent.shape[2] // patch_spatial,
            video_latent.shape[3] // patch_spatial,
        ],
        dtype=torch.int32,
    )
    video_latent = patchify([video_latent], (patch_temporal, patch_spatial, patch_spatial))[0]
    video_latent = torch.as_tensor(video_latent, dtype=torch.float32)
    seq_len_q = video_latent.shape[0]
    seq_len_q_padded = seq_len_q
    loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16)
    context_embeddings = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.float32)
    seq_len_kv = context_embeddings.shape[0]
    seq_len_kv_padded = seq_len_kv
    video_metadata = {}

    # Set mock values for the packed video samples.
    video_latents_packed = torch.cat([video_latent for _ in range(number_packed_samples)], dim=0)
    loss_masks_packed = torch.cat([loss_mask for _ in range(number_packed_samples)], dim=0)
    seq_len_q_packed = torch.tensor([seq_len_q for _ in range(number_packed_samples)], dtype=torch.int32)
    seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(number_packed_samples)], dtype=torch.int32)
    seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(number_packed_samples)], dtype=torch.int32)
    seq_len_kv_padded_packed = torch.tensor(
        [seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32
    )
    grid_sizes_packed = torch.stack([grid_size for _ in range(number_packed_samples)], dim=0)
    context_embeddings_packed = torch.cat([context_embeddings for _ in range(number_packed_samples)], dim=0)

    # Note: shapes of one sample's values
    # video_latent: [num_patches, latents_channels * pF * pH * pW]
    # grid_size: [F_patches, H_patches, W_patches]
    # context_embeddings: [context_seq_len, text_embedding_dim]

    batch = dict(
        video_latents=video_latents_packed.unsqueeze(1),
        context_embeddings=context_embeddings_packed.unsqueeze(1),
        loss_mask=loss_masks_packed.unsqueeze(1),
        seq_len_q=seq_len_q_packed,
        seq_len_q_padded=seq_len_q_padded_packed,
        seq_len_kv=seq_len_kv_packed,
        seq_len_kv_padded=seq_len_kv_padded_packed,
        grid_sizes=grid_sizes_packed,
        video_metadata=video_metadata,
    )

    return batch


@dataclass(kw_only=True)
class WanMockDataModuleConfig(DatasetProvider):
    path: str = ""
    seq_length: int
    packing_buffer_size: int
    micro_batch_size: int
    global_batch_size: int
    num_workers: int
    dataloader_type: str = "external"
    F_latents: int = 24
    H_latents: int = 104
    W_latents: int = 60
    patch_spatial: int = 2
    patch_temporal: int = 1
    number_packed_samples: int = 3
    context_seq_len: int = 512
    context_embeddings_dim: int = 4096

    def __post_init__(self):
        mock_ds = _MockDataset(length=1024)
        self._train_dl = DataLoader(
            mock_ds,
            batch_size=self.micro_batch_size,
            num_workers=self.num_workers,
            collate_fn=lambda samples: mock_batch(
                F_latents=self.F_latents,
                H_latents=self.H_latents,
                W_latents=self.W_latents,
                patch_temporal=self.patch_temporal,
                patch_spatial=self.patch_spatial,
                number_packed_samples=self.number_packed_samples,
                context_seq_len=self.context_seq_len,
                context_embeddings_dim=self.context_embeddings_dim,
            ),
            shuffle=False,
            drop_last=False,
        )
        self.sequence_length = self.seq_length

    def build_datasets(self, _context: DatasetBuildContext):
        if hasattr(self, "dataset"):
            return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
        return self._train_dl, self._train_dl, self._train_dl
```
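Under the mock defaults above (F_latents=24, H_latents=104, W_latents=60, patch_temporal=1, patch_spatial=2, three packed samples), the per-sample token count and the cumulative offsets that THD-layout packed attention consumes can be sketched in plain Python. `patch_count` and `cu_seqlens` are illustrative helper names for this sketch, not functions from the PR:

```python
def patch_count(F_latents, H_latents, W_latents, patch_temporal, patch_spatial):
    # Number of patch tokens per sample: product of the grid_size entries
    # [F // pt, H // ps, W // ps] computed in mock_batch.
    return (
        (F_latents // patch_temporal)
        * (H_latents // patch_spatial)
        * (W_latents // patch_spatial)
    )


def cu_seqlens(seq_lens):
    # Cumulative offsets over packed samples — the layout that THD attention
    # kernels expect (cf. the cu_seqlens_q_padded fix in commit 174bb7b).
    offsets = [0]
    for n in seq_lens:
        offsets.append(offsets[-1] + n)
    return offsets


tokens = patch_count(24, 104, 60, patch_temporal=1, patch_spatial=2)
offsets = cu_seqlens([tokens] * 3)  # three identical packed samples, as in the mock
```

With these defaults each sample contributes 24 × 52 × 30 = 37440 patch tokens, so the packed batch carries offsets [0, 37440, 74880, 112320].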