11import torch
2- from functools import partial
32from megatron import get_args
43from megatron import print_rank_0
5- from megatron import get_timers
64from megatron import get_tokenizer
75from megatron import mpu
8- from megatron .data .gpt_dataset import build_train_valid_test_datasets , build_dataset_group
6+ from megatron .data .mlm_dataset import build_train_valid_test_datasets , build_dataset_group
97from megatron .model import SharedT5ModelPipe
108from megatron .training import pretrain
11- from megatron .utils import get_attention_masks_and_position_ids , get_prefix_indices
12- from megatron .utils import average_losses_across_data_parallel_group
9+ from megatron .utils import get_attention_masks_and_position_ids
1310
1411import deepspeed
1512from deepspeed .runtime .utils import see_memory_usage
16- import os
1713
1814try :
1915 from torch .distributed .elastic .multiprocessing .errors import record
@@ -39,24 +35,6 @@ def model_provider(pre_process=True, post_process=True):
3935 # TODO @thomasw21: fix this for PP > 1 (the issue is that you're passing two values that require grad)
4036 assert mpu .get_pipeline_model_parallel_world_size () != 1 , "PP > 1 is not supported yet"
4137
42- # TODO: actually I'm fairly confident that you don't need the causal mask here as it's handled with `AttnMaskType`
43- # # Precompute the attention mask and store it in args. This avoids having to
44- # # pipeline it as an activation during training. The mask is constant, and thus
45- # # we can reuse it.
46- # attention_mask = torch.tril(torch.ones(
47- # (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view(
48- # 1, 1, args.seq_length, args.seq_length)
49- #
50- # # Convert attention mask to binary:
51- # attention_mask = (attention_mask < 0.5)
52- # if args.fp16:
53- # attention_mask = attention_mask.half()
54- # elif args.bf16:
55- # attention_mask = attention_mask.bfloat16()
56- #
57- # # must be bool or the training crashes expecting bool, but getting Half
58- # args.attn_mask = attention_mask.to(torch.bool)
59-
6038 model = SharedT5ModelPipe (
6139 num_tokentypes = 0 ,
6240 parallel_output = True
@@ -72,12 +50,10 @@ def model_provider(pre_process=True, post_process=True):
7250
7351def get_batch_pipe (data ):
7452 """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
75- raise NotImplementedError ("Waiting for MLM data loader to work" )
7653 args = get_args ()
7754 tokenizer = get_tokenizer ()
7855
7956 # Items and their type.
80- # TODO @thomasw21
8157 keys = ["input_tokens" , "target_tokens" ]
8258 datatype = torch .int64
8359
@@ -116,7 +92,7 @@ def get_batch_pipe(data):
11692
11793def train_valid_test_datasets_provider (train_val_test_num_samples ):
11894 """Build train, valid, and test datasets."""
119- raise NotImplementedError ( "Waiting for MLM data loader" )
95+
12096 args = get_args ()
12197 train_ds , valid_ds , test_ds = None , None , None
12298
@@ -129,9 +105,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
129105 data_impl = args .data_impl ,
130106 splits_string = args .split ,
131107 train_valid_test_num_samples = train_val_test_num_samples ,
132- seq_length = args .seq_length ,
108+ sequence_length = args .seq_length ,
109+ noise_density = args .noise_density ,
110+ mean_noise_span_length = args .mean_noise_span_length ,
133111 seed = args .seed ,
134- skip_warmup = (not args .mmap_warmup ))
112+ skip_warmup = (not args .mmap_warmup )
113+ )
135114 # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
136115 elif args .train_weighted_split_paths :
137116 assigned_train_valid_test = []
@@ -151,12 +130,20 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
151130 eval (f"args.{ s } _weighted_split_splits" ),
152131 eval (f"args.{ s } _weighted_split_names" ))
153132 for paths , weights , splits , name in data_groups :
154- d = build_dataset_group (name , paths , weights , splits ,
155- args .data_impl ,
156- train_val_test_num_samples ,
157- args .seq_length , args .seed ,
158- (not args .mmap_warmup ),
159- train_valid_test = s )
133+ d = build_dataset_group (
134+ dataset_group_name = name ,
135+ paths = paths ,
136+ weights = weights ,
137+ splits = splits ,
138+ data_impl = args .data_impl ,
139+ train_valid_test_num_samples = train_val_test_num_samples ,
140+ seq_length = args .seq_length ,
141+ noise_density = args .noise_density ,
142+ mean_noise_span_length = args .mean_noise_span_length ,
143+ seed = args .seed ,
144+ skip_warmup = (not args .mmap_warmup ),
145+ train_valid_test = s
146+ )
160147 eval (f"{ s } _ds" ).append (d )
161148 else :
162149 raise NotImplementedError ("No dataloading argument passed" )
0 commit comments