Skip to content

Commit 47bdad5

Browse files
authored
Enhance DiT and Wan mock data modules by initializing training data l… (#65)
* Enhance DiT and Wan mock data modules by initializing training data loaders as iterators. Update test configuration for DiT pretraining to include checkpoint save and load paths. * Remove unnecessary DataLoader assertion from Wan mock datamodule tests to streamline test execution.
1 parent f824d18 commit 47bdad5

File tree

4 files changed

+4
-2
lines changed

4 files changed

+4
-2
lines changed

dfm/src/megatron/data/dit/dit_mock_datamodule.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __post_init__(self):
158158
shuffle=False,
159159
drop_last=False,
160160
)
161+
self._train_dl = iter(self._train_dl)
161162
self.sequence_length = self.seq_length
162163

163164
def build_datasets(self, _context: DatasetBuildContext):

dfm/src/megatron/data/wan/wan_mock_datamodule.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __post_init__(self):
136136
shuffle=False,
137137
drop_last=False,
138138
)
139+
self._train_dl = iter(self._train_dl)
139140
self.sequence_length = self.seq_length
140141

141142
def build_datasets(self, _context: DatasetBuildContext):

tests/functional_tests/mcore/recipes/test_dit_pretrain.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def test_DiT_pretrain_mock(self, tmp_path):
5555
"model.context_parallel_size=1",
5656
"model.qkv_format=thd",
5757
"model.num_attention_heads=16",
58+
f"checkpoint.save={checkpoint_dir}",
59+
f"checkpoint.load={checkpoint_dir}",
5860
"dataset.task_encoder_seq_length=4608",
5961
"dataset.seq_length=4608",
6062
"train.global_batch_size=2",

tests/unit_tests/megatron/data/wan/test_wan_mock_datamodule.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import torch
16-
from torch.utils.data import DataLoader
1716

1817
from dfm.src.megatron.data.wan.wan_mock_datamodule import WanMockDataModuleConfig
1918

@@ -37,7 +36,6 @@ def test_wan_mock_datamodule_build_and_batch_shapes():
3736
context_embeddings_dim=64,
3837
)
3938
train_dl, val_dl, test_dl = cfg.build_datasets(_context=None)
40-
assert isinstance(train_dl, DataLoader)
4139
assert train_dl is val_dl and val_dl is test_dl
4240

4341
batch = next(iter(train_dl))

0 commit comments

Comments
 (0)