|
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import OrderedDict
+from types import SimpleNamespace
+
 from megatron.bridge.data.loaders import build_train_valid_test_datasets
 from megatron.bridge.data.samplers import (
     RandomSeedDataset,
@@ -21,6 +24,19 @@
 from megatron.bridge.recipes.llama.llama3 import llama3_8b_pretrain_config as pretrain_config
 
 
+def _mock_tokenizer():
+    """Create a lightweight mock tokenizer for MockGPTLowLevelDataset.
+
+    MockGPTLowLevelDataset requires ``tokenizer.vocab_size`` and
+    ``tokenizer.eod`` when building mock datasets.
+    """
+    return SimpleNamespace(
+        vocab_size=128256,
+        eod=0,
+        unique_identifiers=OrderedDict({"class": "MockTokenizer"}),
+    )
+
+
 class TestDataSamplers:
     def test_build_pretraining_data_loader(self):
         dataloader = build_pretraining_data_loader(
@@ -49,6 +65,7 @@ def to_megatron_provider(self, load_weights=False):
             mock_from.return_value = _DummyBridge()
             cfg = pretrain_config()
             cfg.train.train_iters = 1000
+            cfg.dataset.tokenizer = _mock_tokenizer()
             cfg.dataset.finalize()
             dataset_provider = get_dataset_provider(cfg.dataset)
             dataset = build_train_valid_test_datasets(cfg=cfg, build_train_valid_test_datasets_provider=dataset_provider)
@@ -92,6 +109,7 @@ def to_megatron_provider(self, load_weights=False):
             mock_from.return_value = _DummyBridge()
             cfg = pretrain_config()
             cfg.train.train_iters = 1000
+            cfg.dataset.tokenizer = _mock_tokenizer()
             cfg.dataset.finalize()
             dataset_provider = get_dataset_provider(cfg.dataset)
             dataset = build_train_valid_test_datasets(cfg=cfg, build_train_valid_test_datasets_provider=dataset_provider)
@@ -144,6 +162,7 @@ def to_megatron_provider(self, load_weights=False):
             mock_from.return_value = _DummyBridge()
             cfg = pretrain_config()
             cfg.train.train_iters = 1000
+            cfg.dataset.tokenizer = _mock_tokenizer()
             cfg.dataset.finalize()
             dataset_provider = get_dataset_provider(cfg.dataset)
             dataset = build_train_valid_test_datasets(cfg=cfg, build_train_valid_test_datasets_provider=dataset_provider)
@@ -568,6 +587,7 @@ def to_megatron_provider(self, load_weights=False):
             cfg = pretrain_config()
             cfg.train.train_iters = 1000
             cfg.train.global_batch_size = 16
+            cfg.dataset.tokenizer = _mock_tokenizer()
             cfg.dataset.finalize()
             dataset_provider = get_dataset_provider(cfg.dataset)
             dataset = build_train_valid_test_datasets(cfg=cfg, build_train_valid_test_datasets_provider=dataset_provider)
@@ -604,6 +624,7 @@ def to_megatron_provider(self, load_weights=False):
             mock_from.return_value = _DummyBridge()
             cfg = pretrain_config()
             cfg.train.train_iters = 1000
+            cfg.dataset.tokenizer = _mock_tokenizer()
             cfg.dataset.finalize()
             dataset_provider = get_dataset_provider(cfg.dataset)
             dataset = build_train_valid_test_datasets(cfg=cfg, build_train_valid_test_datasets_provider=dataset_provider)
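Note: a minimal standalone sketch of the new helper, runnable outside the test suite. It assumes, per the docstring above, that MockGPTLowLevelDataset only reads tokenizer.vocab_size and tokenizer.eod (plus unique_identifiers, which is presumably consumed when Megatron hashes the dataset config for index caching; that part is an assumption, not something this diff confirms):

from collections import OrderedDict
from types import SimpleNamespace


def _mock_tokenizer():
    """Stand-in tokenizer exposing only the attributes the mock dataset reads."""
    return SimpleNamespace(
        vocab_size=128256,  # Llama 3 vocabulary size, matching the value in the diff
        eod=0,              # arbitrary end-of-document token id for mock data
        unique_identifiers=OrderedDict({"class": "MockTokenizer"}),
    )


tok = _mock_tokenizer()
assert tok.vocab_size == 128256 and tok.eod == 0

Using SimpleNamespace keeps the stub duck-typed: the tests never tokenize any text, so no real tokenizer (or its checkpoint download) is needed just to build mock datasets.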
|