33import numpy as np
44import torch
55
6- from megatron import print_rank_0 , get_tokenizer
6+ from megatron import print_rank_0 , get_tokenizer , get_args
77from megatron .data .blendable_dataset import BlendableDataset
88from megatron .data .dataset_utils import get_datasets_weights_and_num_samples , get_split_by_range_
99from megatron .data .dataset_utils import get_train_valid_test_split_ , get_indexed_dataset_
@@ -296,14 +296,14 @@ def __init__(
296296 # To ensure that the input length is `sequence_length`, we need to increase the maximum length
297297 # according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
298298 number_of_raw_tokens , inputs_length , targets_length , num_noise_spans = compute_input_and_target_lengths (
299- # +1 is used so that we can compute the loss, as autoregressive systems require us to add one more token.
300- sequence_length = self .sequence_length + 1 ,
299+ sequence_length = self .sequence_length ,
301300 noise_density = self .noise_density ,
302301 mean_noise_span_length = self .mean_noise_span_length
303302 )
304- self .number_of_raw_tokens = number_of_raw_tokens
305303 self .inputs_length = inputs_length
306- self .targets_length = targets_length
304+ # In order to compute loss, we need an extra token at the end.
305+ self .number_of_raw_tokens = number_of_raw_tokens + 1
306+ self .targets_length = targets_length + 1
307307 self .num_noise_spans = num_noise_spans
308308
309309 # Build the samples mapping.
@@ -322,11 +322,20 @@ def __init__(
322322 tokenizer = get_tokenizer ()
323323 self .sep_id = tokenizer .sep
324324 self .sentinel_token_ids = tokenizer .additional_special_tokens_ids
325+ assert self .sep_id is not None , "MLM dataset requires tokenizer to have a <sep> token"
325326 assert len (self .sentinel_token_ids ) > 0 , "Provide the argument --vocab-extra-ids 100 to the script"
326327 assert len (self .sentinel_token_ids ) >= self .num_noise_spans , "Not enough sentinel tokens, please add more"
327328
329+ args = get_args ()
330+ if hasattr (args , "encoder_seq_length" ) and args .encoder_seq_length is not None :
331+ # T5 style
332+ assert self .inputs_length == args .encoder_seq_length
333+ assert self .targets_length == args .decoder_seq_length + 1
334+ else :
335+ assert self .inputs_length + self .targets_length == args .seq_length
336+
328337 def __len__ (self ):
329- return len (self .samples_mapping )
338+ return len (self ._gpt_dataset )
330339
331340 def __getitem__ (self , idx ):
332341 if isinstance (idx , slice ):
0 commit comments