diff --git a/dfm/src/megatron/data/dit/dit_mock_datamodule.py b/dfm/src/megatron/data/dit/dit_mock_datamodule.py index 4011205e..3d5e9cf3 100644 --- a/dfm/src/megatron/data/dit/dit_mock_datamodule.py +++ b/dfm/src/megatron/data/dit/dit_mock_datamodule.py @@ -158,6 +158,7 @@ def __post_init__(self): shuffle=False, drop_last=False, ) + self._train_dl = iter(self._train_dl) self.sequence_length = self.seq_length def build_datasets(self, _context: DatasetBuildContext): diff --git a/dfm/src/megatron/data/wan/wan_mock_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py index 844ecd9b..8837ea4d 100644 --- a/dfm/src/megatron/data/wan/wan_mock_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -136,6 +136,7 @@ def __post_init__(self): shuffle=False, drop_last=False, ) + self._train_dl = iter(self._train_dl) self.sequence_length = self.seq_length def build_datasets(self, _context: DatasetBuildContext): diff --git a/tests/functional_tests/mcore/recipes/test_dit_pretrain.py b/tests/functional_tests/mcore/recipes/test_dit_pretrain.py index e5ca6ff4..2f094431 100644 --- a/tests/functional_tests/mcore/recipes/test_dit_pretrain.py +++ b/tests/functional_tests/mcore/recipes/test_dit_pretrain.py @@ -55,6 +55,8 @@ def test_DiT_pretrain_mock(self, tmp_path): "model.context_parallel_size=1", "model.qkv_format=thd", "model.num_attention_heads=16", + f"checkpoint.save={checkpoint_dir}", + f"checkpoint.load={checkpoint_dir}", "dataset.task_encoder_seq_length=4608", "dataset.seq_length=4608", "train.global_batch_size=2", diff --git a/tests/unit_tests/megatron/data/wan/test_wan_mock_datamodule.py b/tests/unit_tests/megatron/data/wan/test_wan_mock_datamodule.py index e1980052..ede3494f 100644 --- a/tests/unit_tests/megatron/data/wan/test_wan_mock_datamodule.py +++ b/tests/unit_tests/megatron/data/wan/test_wan_mock_datamodule.py @@ -13,7 +13,6 @@ # limitations under the License. import torch -from torch.utils.data import DataLoader from dfm.src.megatron.data.wan.wan_mock_datamodule import WanMockDataModuleConfig @@ -37,7 +36,6 @@ def test_wan_mock_datamodule_build_and_batch_shapes(): context_embeddings_dim=64, ) train_dl, val_dl, test_dl = cfg.build_datasets(_context=None) - assert isinstance(train_dl, DataLoader) assert train_dl is val_dl and val_dl is test_dl batch = next(iter(train_dl))