
Commit 131bd43

Merge MLM too fast 2 (#294)
* Merge MLM too fast 2
* Update megatron/data/mlm_dataset.py
1 parent 3ab0ad1 commit 131bd43

2 files changed (+17, -8 lines)


megatron/arguments.py

Lines changed: 2 additions & 2 deletions

@@ -927,8 +927,8 @@ def __call__(self, parser, args, values, option_string=None):
              'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
              'positions based on how frequently we train on that position.'
              'This is mostly used for prefix_lm training')
-    group.add_argument("--noise_density", type=float, default=None, help="Span corruption noise density")
-    group.add_argument("--mean_noise_span_length", type=int, default=None, help="Span corruption mean noise span length")
+    group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
+    group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
 
 
     return parser
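
The flags are renamed from `--noise_density`/`--mean_noise_span_length` to the dash-separated `--noise-density`/`--mean-noise-span-length`, matching the convention of the other Megatron arguments. argparse converts dashes in long option names to underscores when building the namespace, so downstream code still reads `args.noise_density` and `args.mean_noise_span_length`. A minimal standalone sketch (the parser/group setup here is simplified and is not the repo's actual arguments.py code):

import argparse

# Standalone illustration of the renamed flags; not the repo's parser setup.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="data")
group.add_argument("--noise-density", type=float, default=None,
                   help="Span corruption noise density")
group.add_argument("--mean-noise-span-length", type=int, default=None,
                   help="Span corruption mean noise span length")

args = parser.parse_args(["--noise-density", "0.15", "--mean-noise-span-length", "3"])
# argparse maps dashes to underscores, so attribute names are unchanged by the rename.
print(args.noise_density, args.mean_noise_span_length)  # 0.15 3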

megatron/data/mlm_dataset.py

Lines changed: 15 additions & 6 deletions

@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 
-from megatron import print_rank_0, get_tokenizer
+from megatron import print_rank_0, get_tokenizer, get_args
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_
 from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
@@ -296,14 +296,14 @@ def __init__(
         # To ensure that the input length is `sequence_length`, we need to increase the maximum length
         # according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
         number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
-            # +1 is used so that we can compute the as autoregressive systems require us to add one more token.
-            sequence_length=self.sequence_length + 1,
+            sequence_length=self.sequence_length,
             noise_density=self.noise_density,
             mean_noise_span_length=self.mean_noise_span_length
         )
-        self.number_of_raw_tokens = number_of_raw_tokens
         self.inputs_length = inputs_length
-        self.targets_length = targets_length
+        # In order to compute loss, we need an extra token at the end.
+        self.number_of_raw_tokens = number_of_raw_tokens + 1
+        self.targets_length = targets_length + 1
         self.num_noise_spans = num_noise_spans
 
         # Build the samples mapping.
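
This hunk moves the autoregressive "+1" out of the call to compute_input_and_target_lengths: the helper is now given the plain `sequence_length`, and the extra token needed to shift inputs against labels for the loss is added afterwards, onto `number_of_raw_tokens` and `targets_length`. The sketch below illustrates T5-style span-corruption length arithmetic plus that trailing extra token; it is an approximation for intuition, not the repo's compute_input_and_target_lengths implementation.

def span_corruption_lengths(tokens_length, noise_density, mean_noise_span_length):
    # Rough T5-style arithmetic: from a raw span of `tokens_length` tokens,
    # estimate how long the sentinel-masked inputs and targets become.
    # Illustrative only; the repo's helper may round and search differently.
    num_noise_tokens = int(round(tokens_length * noise_density))
    num_noise_spans = max(1, int(round(num_noise_tokens / mean_noise_span_length)))
    num_nonnoise_tokens = tokens_length - num_noise_tokens
    # Inputs keep every non-noise token plus one sentinel per span (+ EOS);
    # targets keep the noise tokens plus one sentinel per span (+ EOS).
    inputs_length = num_nonnoise_tokens + num_noise_spans + 1
    targets_length = num_noise_tokens + num_noise_spans + 1
    return inputs_length, targets_length

inputs_length, targets_length = span_corruption_lengths(
    tokens_length=568, noise_density=0.15, mean_noise_span_length=3.0)
print(inputs_length, targets_length)  # 512 114 for these example numbers

# After this commit, the extra token for the loss shift is accounted for here,
# rather than by passing sequence_length + 1 into the helper:
number_of_raw_tokens = 568 + 1
targets_length = targets_length + 1   # 115: one extra position so labels can be shifted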
@@ -322,11 +322,20 @@ def __init__(
         tokenizer = get_tokenizer()
         self.sep_id = tokenizer.sep
         self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
+        assert self.sep_id is not None, "MLM dataset requires tokenizer to have a <sep> token"
         assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
         assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"
 
+        args = get_args()
+        if hasattr(args, "encoder_seq_length") and args.encoder_seq_length is not None:
+            # T5 style
+            assert self.inputs_length == args.encoder_seq_length
+            assert self.targets_length == args.decoder_seq_length + 1
+        else:
+            assert self.inputs_length + self.targets_length == args.seq_length
+
     def __len__(self):
-        return len(self.samples_mapping)
+        return len(self._gpt_dataset)
 
     def __getitem__(self, idx):
         if isinstance(idx, slice):
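
The new assertions tie the dataset's computed lengths back to the command-line sequence-length flags: for a T5-style encoder/decoder run, the corrupted inputs must fill the encoder sequence and the targets (which already include the extra shift token) must equal the decoder sequence plus one; for a decoder-only run, inputs and targets are packed together into `--seq-length`. Below is a small hypothetical helper mirroring those checks (it does not exist in the repo), with example numbers that satisfy the decoder-only branch.

def check_seq_length_config(inputs_length, targets_length,
                            encoder_seq_length=None, decoder_seq_length=None,
                            seq_length=None):
    # Hypothetical stand-alone mirror of the assertions added in __init__ above;
    # not a function from megatron/data/mlm_dataset.py.
    if encoder_seq_length is not None:
        # T5 style: encoder consumes the corrupted inputs, decoder the targets.
        # targets_length already carries the +1 added for the loss shift.
        assert inputs_length == encoder_seq_length
        assert targets_length == decoder_seq_length + 1
    else:
        # Decoder-only (e.g. prefix-LM style): inputs and targets share one sequence.
        assert inputs_length + targets_length == seq_length

# Example consistent with the sketch above: inputs_length=512, targets_length=115
# would require --seq-length 627 in a decoder-only run.
check_seq_length_config(512, 115, seq_length=627)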
