@@ -148,7 +148,7 @@ def _num_tokens(documents, lens):


 def _num_epochs(tokens_per_epoch, seq_length, num_samples):
-    """Based on number of samples and sequence lenght, calculate how many
+    """Based on number of samples and sequence length, calculate how many
     epochs will be needed."""
     num_epochs = 0
     total_tokens = 0
@@ -256,18 +256,17 @@ def get_train_valid_test_split_(splits_string, size):
     return splits_index


-def create_pretrained_dataset(
-        args,
-        input_path,
-        local_rank,
-        data_world_rank,
-        data_world_size,
-        eos_id,
-        worker_init=None,
-        max_seq_len=1024,
-        places=None,
-        data_holders=None,
-        pipeline_mode=False, ):
+def create_pretrained_dataset(args,
+                              input_path,
+                              local_rank,
+                              data_world_rank,
+                              data_world_size,
+                              eos_id,
+                              worker_init=None,
+                              max_seq_len=1024,
+                              places=None,
+                              data_holders=None,
+                              pipeline_mode=False):

     if local_rank == 0:
         start_time = time.time()
@@ -339,7 +338,8 @@ def build_dataset(index, name, num_samples):
             sample_lens=sample_lens,
             eos_id=eos_id,
             seed=args.seed,
-            use_pure_fp16=args.use_amp and args.amp_level == "O2")
+            use_pure_fp16=args.use_amp and args.amp_level == "O2",
+            data_holders=data_holders)
         batch_sampler = DistributedBatchSampler(
             dataset,
             batch_size=args.micro_batch_size,
@@ -361,14 +361,16 @@ def data_gen():
             data_loader.set_sample_generator(
                 data_gen, batch_size=args.micro_batch_size, places=places)
         else:
+            stacks = (Stack(), ) * len(data_holders)
+            collate_fn = Tuple(*stacks)
             data_loader = DataLoader(
                 dataset=dataset,
                 places=places,
                 feed_list=data_holders,
                 batch_sampler=batch_sampler,
                 num_workers=1,
                 worker_init_fn=worker_init,
-                collate_fn=Tuple(Stack(), Stack(), Stack(), Stack()),
+                collate_fn=collate_fn,
                 return_list=False)
         return data_loader

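For context, the two added lines build the collate function dynamically. Below is a minimal, standalone sketch (not part of the commit) of that behaviour, assuming Stack and Tuple are the batchify helpers imported from paddlenlp.data and using made-up field names: Tuple(*stacks) applies one Stack() per field, so the number of collated outputs follows len(data_holders) instead of the previously hard-coded four.

import numpy as np
from paddlenlp.data import Stack, Tuple

# Hypothetical three-holder configuration (names are illustrative only).
data_holders = ["tokens", "loss_mask", "position_ids"]
stacks = (Stack(), ) * len(data_holders)
collate_fn = Tuple(*stacks)

# Two samples, each a (tokens, loss_mask, position_ids) tuple.
batch = [
    (np.array([11, 12]), np.array([1.0, 1.0]), np.array([0, 1])),
    (np.array([13, 14]), np.array([1.0, 0.0]), np.array([0, 1])),
]
tokens, loss_mask, position_ids = collate_fn(batch)
print(tokens.shape, loss_mask.shape, position_ids.shape)  # (2, 2) for each field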
@@ -401,7 +403,8 @@ def __init__(self,
                  name="gpt",
                  max_seq_len=1024,
                  seed=1234,
-                 use_pure_fp16=False):
+                 use_pure_fp16=False,
+                 data_holders=None):
         self.file_prefix = file_prefix
         self.max_seq_len = max_seq_len
         self.name = name
@@ -410,6 +413,7 @@ def __init__(self,
         self.sample_lens = sample_lens
         self.micro_batch_size = micro_batch_size
         self.use_pure_fp16 = use_pure_fp16
+        self.data_holders = data_holders

         if documents is None:
             document_ids = np.arange(0, self.sample_lens.shape[0])
@@ -435,10 +439,17 @@ def _construct_sample(self, tokens):
         else:
             loss_mask = np.ones(seq_length, dtype="float32")
             loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
-        position_ids = np.arange(0, seq_length, dtype="int64")

+        position_ids = np.arange(0, seq_length, dtype="int64")
         labels = np.array(labels, dtype="int64")
-        return [tokens, loss_mask, position_ids, labels]
+        if len(self.data_holders) == 4:
+            return [tokens, loss_mask, position_ids, labels]
+        elif len(self.data_holders) == 3:
+            return [tokens, loss_mask, position_ids]
+        else:
+            assert len(self.data_holders) == 1, \
+                "length of data_holders should be 4, 3 or 1"
+            return [tokens]

     def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f,
                                     offset_l):
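The new branching in _construct_sample keeps the number of returned fields in step with the feed list. A small standalone mimic follows (assumed names and toy data; only the 4/3/1 branching mirrors the change above):

import numpy as np

def construct_sample(tokens, labels, eos_id, num_holders):
    # Mirrors the idea of GPTDataset._construct_sample: zero the loss mask at
    # EOS positions, then return only as many fields as there are feed holders.
    seq_length = len(tokens)
    loss_mask = np.ones(seq_length, dtype="float32")
    loss_mask[np.where(np.array(tokens) == eos_id)] = 0.0
    position_ids = np.arange(0, seq_length, dtype="int64")
    labels = np.array(labels, dtype="int64")
    if num_holders == 4:
        return [tokens, loss_mask, position_ids, labels]
    elif num_holders == 3:
        return [tokens, loss_mask, position_ids]
    else:
        assert num_holders == 1, "length of data_holders should be 4, 3 or 1"
        return [tokens]

sample = construct_sample(tokens=[5, 6, 7, 0], labels=[6, 7, 0, 5],
                          eos_id=0, num_holders=3)
print(len(sample))  # 3 fields: tokens, loss_mask, position_ids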