Skip to content

Commit 8c82f21

Browse files
committed
move data modules to common, move cosmos to common/tokenizer instad of common/models. Make inference compatible with the new data loading.
Signed-off-by: Sajad Norouzi <snorouzi@nvidia.com>
1 parent 0180c3e commit 8c82f21

File tree

18 files changed

+17
-109
lines changed

18 files changed

+17
-109
lines changed
File renamed without changes.
File renamed without changes.

dfm/src/common/models/cosmos/cosmos1/causal_video_tokenizer.py renamed to dfm/src/common/tokenizers/cosmos/cosmos1/causal_video_tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from huggingface_hub import hf_hub_download
2626
from tqdm import tqdm
2727

28-
from dfm.src.common.models.cosmos.cosmos1.tokenizer_utils import (
28+
from dfm.src.common.tokenizers.cosmos.cosmos1.video_tokenizer_utils import (
2929
load_jit_model,
3030
numpy2tensor,
3131
pad_video_batch,

dfm/src/common/models/cosmos/cosmos1/tokenizer_utils.py renamed to dfm/src/common/tokenizers/cosmos/cosmos1/video_tokenizer_utils.py

File renamed without changes.

dfm/src/common/utils/save_video.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import imageio
1616
import numpy as np
17-
import torch
1817

1918

2019
def save_video(
@@ -35,12 +34,3 @@ def save_video(
3534

3635
print("video_save_path", video_save_path)
3736
imageio.mimsave(video_save_path, grid, "mp4", **kwargs)
38-
39-
40-
def print_dict(dict):
41-
for key, value in dict.items():
42-
if isinstance(value, torch.Tensor):
43-
print(key, value.shape)
44-
else:
45-
print(key, value)
46-
print("-" * 40)

dfm/src/megatron/data/common/__init__.py

Whitespace-only changes.

dfm/src/megatron/data/dit/base_energon_datamodule.py renamed to dfm/src/megatron/data/common/base_energon_datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
class EnergonMultiModalDataModule:
2626
"""
27-
A PyTorch Lightning DataModule for handling multimodal datasets with images and text.
27+
A DataModule for handling multimodal datasets with images and text.
2828
2929
This data module is designed to work with multimodal datasets that involve both images and text.
3030
It provides a seamless interface to load training and validation data, manage batching, and handle

dfm/src/megatron/data/dit/diffusion_energon_datamodule.py renamed to dfm/src/megatron/data/common/diffusion_energon_datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from megatron.energon import DefaultTaskEncoder, get_train_dataset
2424
from torch import int_repr
2525

26-
from dfm.src.megatron.data.dit.base_energon_datamodule import EnergonMultiModalDataModule
26+
from dfm.src.megatron.data.common.base_energon_datamodule import EnergonMultiModalDataModule
2727
from dfm.src.megatron.data.dit.dit_taskencoder import DiTTaskEncoder
2828

2929

File renamed without changes.

0 commit comments

Comments
 (0)