 
 """Dataloaders."""
 
+from functools import partial
 
+import numpy as np
 import torch
-import random
-from megatron import get_args
+
+from megatron import get_args, get_tokenizer
 from megatron import mpu
+from megatron.data.mtf_dataset import MTFDataset
+
+
+def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
+    """
+    Greedily packs samples.
+
+    Items:
+        [
+            {
+                'input_tokens': array([6, 7]),
+                'target_tokens': array([8])
+            },
+            {
+                'input_tokens': array([3, 4]),
+                'target_tokens': array([5])
+            }
+        ]
+
+    Output:
+        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens, followed by padding tokens.
+        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids mark which original document each token belongs to (`0` is reserved for padding).
+        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` marks input tokens, `0` marks target and padding tokens.
+    """
+
+    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
+    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
+    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))
+
+    batch_num = 0
+    # `0` is reserved for padding
+    item_num = 1
+    cur_len = 0
+    for token_dict in items:
+        input_token_len = len(token_dict["input_tokens"])
+        target_token_len = len(token_dict["target_tokens"])
+        total_len = input_token_len + target_token_len
+        if cur_len + total_len > max_seq_len:
+            len_diff = max_seq_len - cur_len
+            # Padding
+            if len_diff > 0:
+                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
+                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
+                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
+            batch_num += 1
+            assert batch_num < micro_batch_size
+            item_num = 1
+            cur_len = 0
+
+        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
+        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
+        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
+        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1  # input
+        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0  # target
+
+        item_num += 1
+        cur_len += total_len
+        assert cur_len < max_seq_len
+
+    return {
+        "decoder_target_tokens": decoder_target_tokens,
+        "decoder_segment_ids": decoder_segment_ids,
+        "decoder_causal_attention": decoder_causal_attention,
+    }
 
 
 def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
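
As a quick sanity check on the packing above (not part of the commit), this is what `pack_samples` returns for the two items in its docstring; the pad id `0` and `max_seq_len=7` are made-up values for illustration:

```python
import numpy as np

# Hypothetical usage of pack_samples on the docstring example; pad id 0 is an assumption.
items = [
    {'input_tokens': np.array([6, 7]), 'target_tokens': np.array([8])},
    {'input_tokens': np.array([3, 4]), 'target_tokens': np.array([5])},
]
out = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)
# out['decoder_target_tokens']    -> [[6, 7, 8, 3, 4, 5, 0]]       (0 is the pad id)
# out['decoder_segment_ids']      -> [[1., 1., 1., 2., 2., 2., 0.]]
# out['decoder_causal_attention'] -> [[1., 1., 0., 1., 1., 0., 0.]]
```
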
@@ -44,18 +110,39 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
+    elif args.dataloader_type == 'decoder_packed':
+        assert isinstance(dataset, MTFDataset)
+        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
+            sequence_length=args.seq_length + 1,
+            dataset=dataset,
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
     else:
         raise Exception('{} dataloader type is not supported.'.format(
-                args.dataloader_type))
+            args.dataloader_type))
 
     if num_workers is None:
         num_workers = args.num_workers
 
+    collate_fn = None
+    if args.dataloader_type == 'decoder_packed':
+        assert isinstance(dataset, MTFDataset)
+        pad_token = get_tokenizer().pad
+        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
+                             pad_token=pad_token)
+
     # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+        pin_memory=True
+    )
+
 
 class MegatronPretrainingSampler:
 
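
For reference (again, not part of the commit), the `decoder_packed` collate step just freezes every argument of `pack_samples` except the fetched samples. A minimal sketch, with made-up stand-ins for `args.seq_length + 1`, `args.micro_batch_size` and `get_tokenizer().pad`; the `+ 1` presumably follows the usual GPT data path, which keeps one extra position so inputs and shifted labels can both be `seq_length` long:

```python
from functools import partial

import numpy as np

SEQ_LEN, MICRO_BS, PAD = 7, 1, 0  # stand-ins for args.seq_length + 1, args.micro_batch_size, get_tokenizer().pad

collate_fn = partial(pack_samples, max_seq_len=SEQ_LEN, micro_batch_size=MICRO_BS, pad_token=PAD)

# The DataLoader hands collate_fn the list of samples fetched for one index batch.
fetched = [
    {'input_tokens': np.array([6, 7]), 'target_tokens': np.array([8])},
    {'input_tokens': np.array([3, 4]), 'target_tokens': np.array([5])},
]
batch = collate_fn(fetched)
assert batch['decoder_target_tokens'].shape == (MICRO_BS, SEQ_LEN)
```
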
@@ -141,7 +228,7 @@ def __iter__(self):
 
         # data sharding and random sampling
         bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
+            * self.micro_batch_size
         bucket_offset = current_epoch_samples // self.data_parallel_size
         start_idx = self.data_parallel_rank * bucket_size
 
@@ -158,3 +245,76 @@ def __iter__(self):
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
+
+
+class MegatronDecoderPackedText2TextRandomSampler(object):
+    """
+    Batch sampler for a two-stream dataset with `input_tokens` and `target_tokens`: it yields batches of sample
+    indices that can be greedily packed into fixed-length decoder sequences.
+
+    To be used with `pack_samples` as the collate_fn.
+    """
+
+    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
+        # Keep a copy of input params for later use.
+        self.dataset = dataset
+        self.sequence_length = sequence_length
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+
+    def __len__(self):
+        return self.total_samples
+
+    def __iter__(self):
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
+        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
+
+        # data sharding and random sampling
+        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+            * self.micro_batch_size
+        bucket_offset = current_epoch_samples // self.data_parallel_size
+        start_idx = self.data_parallel_rank * bucket_size
+
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+
+        random_idx = torch.randperm(bucket_size, generator=g).tolist()
+        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+
+        batch = []
+        batch_count = 0
+        token_lens = 0
+        # Last batch if not complete will be dropped.
+        for idx in idx_range:
+            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
+            if token_lens + tok_len > self.sequence_length:
+                batch_count += 1
+                token_lens = 0
+
+            if batch_count == self.micro_batch_size:
+                self.consumed_samples += self.micro_batch_times_data_parallel_size
+                yield batch
+                batch_count = 0
+                batch = []
+            else:
+                token_lens += tok_len
+                batch.append(idx)
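
Tying the pieces together, here is a hedged end-to-end sketch (single data-parallel rank, `micro_batch_size=1`, a toy in-memory list standing in for `MTFDataset`, pad id `0`). The sampler yields batches of *indices* whose combined token count fits in `sequence_length`; the DataLoader would then fetch those samples and hand them to `pack_samples` via `collate_fn`, which the last lines mimic by hand:

```python
import numpy as np

# Toy stand-in for an MTFDataset: every sample is 3 tokens long.
toy_dataset = [
    {'input_tokens': np.array([6, 7]), 'target_tokens': np.array([8])},
    {'input_tokens': np.array([3, 4]), 'target_tokens': np.array([5])},
    {'input_tokens': np.array([9, 10]), 'target_tokens': np.array([11])},
    {'input_tokens': np.array([12, 13]), 'target_tokens': np.array([14])},
]

sampler = MegatronDecoderPackedText2TextRandomSampler(
    sequence_length=7,            # args.seq_length + 1 at the real call site
    dataset=toy_dataset,
    total_samples=len(toy_dataset),
    consumed_samples=0,
    micro_batch_size=1,
    data_parallel_rank=0,
    data_parallel_size=1,
)

for idx_batch in sampler:
    # With micro_batch_size=1 each yielded index batch fills a single packed row.
    assert sum(len(toy_dataset[i]['input_tokens']) + len(toy_dataset[i]['target_tokens'])
               for i in idx_batch) <= 7
    packed = pack_samples([toy_dataset[i] for i in idx_batch],
                          max_seq_len=7, micro_batch_size=1, pad_token=0)
    break
```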