Skip to content

Commit 55f8cf8

Browse files
authored
Fixed MLM dataset arguments (#290)
1 parent 9d26431 commit 55f8cf8

File tree

1 file changed

+39
-17
lines changed

1 file changed

+39
-17
lines changed

megatron/data/mlm_dataset.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,20 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
7676
return (blending_train_dataset, blending_valid_dataset,
7777
blending_test_dataset)
7878

79-
def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl,
80-
train_valid_test_num_samples,
81-
seq_length, seed, skip_warmup, train_valid_test):
79+
def build_dataset_group(
80+
dataset_group_name,
81+
paths,
82+
weights,
83+
splits,
84+
data_impl,
85+
train_valid_test_num_samples,
86+
seq_length,
87+
noise_density,
88+
mean_noise_span_length,
89+
seed,
90+
skip_warmup,
91+
train_valid_test
92+
):
8293
'''
8394
Build a single dataset group corresponding to Option 2 of data loading see arguments.py
8495
a dataset group is passed on the following form
@@ -91,12 +102,18 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl,
91102

92103
# Single dataset.
93104
if len(paths) == 1:
94-
dataset = _build_single_datasets(paths[0],
95-
splits[0],
96-
data_impl,
97-
train_valid_test_num_samples,
98-
seq_length, seed, skip_warmup,
99-
dataset_group_name, train_valid_test)
105+
dataset = _build_single_datasets(
106+
data_prefix=paths[0],
107+
range_string=splits[0],
108+
data_impl=data_impl,
109+
train_valid_test_num_samples=train_valid_test_num_samples,
110+
sequence_length=seq_length,
111+
noise_density=noise_density,
112+
mean_noise_span_length=mean_noise_span_length,
113+
seed=seed,
114+
skip_warmup=skip_warmup,
115+
dataset_group_name=dataset_group_name,
116+
train_valid_test=train_valid_test)
100117
return dataset
101118
# Blending dataset.
102119
else:
@@ -114,14 +131,19 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl,
114131
# Build individual datasets.
115132
datasets = []
116133
for i in range(len(prefixes)):
117-
ds = _build_single_datasets(prefixes[i],
118-
splits[i],
119-
data_impl,
120-
datasets_train_valid_test_num_samples[i],
121-
seq_length,
122-
seed, skip_warmup,
123-
dataset_group_name, train_valid_test)
124-
134+
ds = _build_single_datasets(
135+
data_prefix=prefixes[i],
136+
range_string=splits[i],
137+
data_impl=data_impl,
138+
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
139+
sequence_length=seq_length,
140+
noise_density=noise_density,
141+
mean_noise_span_length=mean_noise_span_length,
142+
seed=seed,
143+
skip_warmup=skip_warmup,
144+
dataset_group_name=dataset_group_name,
145+
train_valid_test=train_valid_test
146+
)
125147
datasets.append(ds)
126148
all_datasets = BlendableDataset(datasets, weights)
127149

0 commit comments

Comments (0)