 
 """Dataloaders."""
 
-from functools import partial
-
-import numpy as np
 import torch
 
-from megatron import get_args, get_tokenizer
+from megatron import get_args
 from megatron import mpu
-from megatron.data.mtf_dataset import MTFDataset
-
-
-def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
-    """
-    Greedily packs samples.
-
-    Items:
-        [
-            {
-                'input_tokens': array([6, 7]),
-                'target_tokens': array([8])
-            },
-            {
-                'input_tokens': array([3, 4]),
-                'target_tokens': array([5])
-            }
-        ]
-
-    Output:
-        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
-        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
-        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts inputs, `0` depicts target (and padding).
48- """
49-
50- decoder_target_tokens = np .full ((micro_batch_size , max_seq_len ), pad_token )
51- decoder_segment_ids = np .zeros ((micro_batch_size , max_seq_len ))
52- decoder_causal_attention = np .zeros ((micro_batch_size , max_seq_len ))
53-
54- batch_num = 0
55- # `0` is reserved for padding
56- item_num = 1
57- cur_len = 0
58- for token_dict in items :
59- input_token_len = len (token_dict ["input_tokens" ])
60- target_token_len = len (token_dict ["target_tokens" ])
61- total_len = input_token_len + target_token_len
62- if cur_len + total_len > max_seq_len :
63- len_diff = max_seq_len - cur_len
64- # Padding
65- if len_diff > 0 :
66- decoder_target_tokens [batch_num ][cur_len : max_seq_len ] = pad_token
67- decoder_segment_ids [batch_num ][cur_len : max_seq_len ] = 0
68- decoder_causal_attention [batch_num ][cur_len : max_seq_len ] = 0
69- batch_num += 1
70- assert batch_num < micro_batch_size
71- item_num = 1
72- cur_len = 0
73-
74- decoder_target_tokens [batch_num ][cur_len : cur_len + input_token_len ] = token_dict ["input_tokens" ]
75- decoder_target_tokens [batch_num ][cur_len + input_token_len : cur_len + total_len ] = token_dict ["target_tokens" ]
76- decoder_segment_ids [batch_num ][cur_len : cur_len + total_len ] = item_num
77- decoder_causal_attention [batch_num ][cur_len : cur_len + input_token_len ] = 1 # input
78- decoder_causal_attention [batch_num ][cur_len + input_token_len : cur_len + total_len ] = 0 # target
79-
80- item_num += 1
81- cur_len += total_len
82- assert cur_len < max_seq_len
83-
84- return {
85- "decoder_target_tokens" : decoder_target_tokens ,
86- "decoder_segment_ids" : decoder_segment_ids ,
87- "decoder_causal_attention" : decoder_causal_attention ,
88- }
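For reference, a minimal sketch of how the removed `pack_samples` collate function behaves on the two items from its own docstring. The call is illustrative only; `pad_token=0` and `max_seq_len=7` are assumptions chosen so the result matches the docstring example.

```python
import numpy as np

# Illustrative only: reproduce the docstring example with the removed pack_samples.
items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]
packed = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)  # pad_token=0 assumed
# Expected, per the docstring:
#   packed["decoder_target_tokens"]    -> [[6, 7, 8, 3, 4, 5, 0]]  (trailing 0 is the pad token)
#   packed["decoder_segment_ids"]      -> [[1, 1, 1, 2, 2, 2, 0]]  (document ids; 0 marks padding)
#   packed["decoder_causal_attention"] -> [[1, 1, 0, 1, 1, 0, 0]]  (1 = input positions, 0 = target/padding)
```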
+from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset
 
 
 def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
@@ -110,41 +44,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
-    elif args.dataloader_type == 'decoder_packed':
-        assert isinstance(dataset, MTFDataset)
-        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
-            sequence_length=args.seq_length + 1,
-            dataset=dataset,
-            total_samples=len(dataset),
-            consumed_samples=consumed_samples,
-            micro_batch_size=args.micro_batch_size,
-            data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
     else:
         raise Exception('{} dataloader type is not supported.'.format(
             args.dataloader_type))
 
     if num_workers is None:
         num_workers = args.num_workers
 
-    collate_fn = None
-    if args.dataloader_type == 'decoder_packed':
-        assert isinstance(dataset, MTFDataset)
-        pad_token = get_tokenizer().pad
-        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
-                             pad_token=pad_token)
-
     # Torch dataloader.
     return torch.utils.data.DataLoader(
         dataset,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
         generator=torch.Generator().manual_seed(args.seed),
-        collate_fn=collate_fn,
+        collate_fn=None,
         pin_memory=True
     )
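For context, a hedged sketch of how the returned DataLoader is typically consumed; `train_ds` is a placeholder for whatever dataset the caller built (e.g. the new `DecoderPackedMTFDataset`), not a name from this file.

```python
# Illustrative only: one micro-batch per next() call, sharded per data-parallel rank
# by the batch sampler chosen above.
train_dataloader = build_pretraining_data_loader(train_ds, consumed_samples=0)
data_iterator = iter(train_dataloader)
micro_batch = next(data_iterator)  # structure follows the dataset's __getitem__ output
```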
 
-
 class MegatronPretrainingSampler:
 
     def __init__(self, total_samples, consumed_samples, micro_batch_size,
@@ -246,76 +162,3 @@ def __iter__(self):
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
-
-
-class MegatronDecoderPackedText2TextRandomSampler(object):
-    """
-    Converts a two-stream dataset with `input_tokens` and `target_tokens` into batches of indices that can be
-    greedily packed and passed on to the decoder model.
-
-    To be used with `pack_samples` as collate_fn.
-    """
-
-    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
-        # Keep a copy of input params for later use.
-        self.dataset = dataset
-        self.sequence_length = sequence_length
-        self.total_samples = total_samples
-        self.consumed_samples = consumed_samples
-        self.micro_batch_size = micro_batch_size
-        self.data_parallel_rank = data_parallel_rank
-        self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
-        self.last_batch_size = \
-            self.total_samples % self.micro_batch_times_data_parallel_size
-
-        # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
-        assert self.micro_batch_size > 0
-        assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
-
-    def __len__(self):
-        return self.total_samples
-
-    def __iter__(self):
-        active_total_samples = self.total_samples - self.last_batch_size
-        self.epoch = self.consumed_samples // active_total_samples
-        current_epoch_samples = self.consumed_samples % active_total_samples
-        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
-
-        # data sharding and random sampling
-        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
-        bucket_offset = current_epoch_samples // self.data_parallel_size
-        start_idx = self.data_parallel_rank * bucket_size
-
-        g = torch.Generator()
-        g.manual_seed(self.epoch)
-
-        random_idx = torch.randperm(bucket_size, generator=g).tolist()
-        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
-
-        batch = []
-        batch_count = 0
-        token_lens = 0
-        # Last batch if not complete will be dropped.
-        for idx in idx_range:
-            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
-            if token_lens + tok_len > self.sequence_length:
-                batch_count += 1
-                token_lens = 0
-
-            if batch_count == self.micro_batch_size:
-                self.consumed_samples += self.micro_batch_times_data_parallel_size
-                yield batch
-                batch_count = 0
-                batch = []
-            else:
-                token_lens += tok_len
-                batch.append(idx)
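Finally, a minimal sketch of how the removed sampler was wired up in the old `decoder_packed` path. Argument values mirror the deleted call site above; `mtf_dataset` is a placeholder for an `MTFDataset` instance, not a name from this file.

```python
# Illustrative only: the sampler yields flat lists of dataset indices whose combined
# input+target lengths greedily fill micro_batch_size rows of sequence_length tokens;
# pack_samples (the old collate_fn) then turned each list into the padded arrays above.
sampler = MegatronDecoderPackedText2TextRandomSampler(
    sequence_length=args.seq_length + 1,
    dataset=mtf_dataset,
    total_samples=len(mtf_dataset),
    consumed_samples=0,
    micro_batch_size=args.micro_batch_size,
    data_parallel_rank=mpu.get_data_parallel_rank(),
    data_parallel_size=mpu.get_data_parallel_world_size(),
)
for packed_indices in sampler:
    pass  # each `packed_indices` list backs one micro-batch on this data-parallel rank
```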