 
 """UL2-style dataset."""
 
+import math
+
 import numpy as np
 
 from megatron import get_tokenizer
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER
 
 
 class UL2Dataset(T5Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):
 
         if denoiser_ratios is None:
             # Uniform distribution by default.
@@ -52,6 +72,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
                          short_seq_prob, seed)
 
         # Params to store.
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
@@ -97,21 +118,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)
 
 
 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.
 
     Arguments:
@@ -125,6 +146,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
         mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: Which UL2ModelType is used (encoder-decoder or decoder-only).
         denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: What type of UL2 denoising objective the other UL2
             configurations refer to.
@@ -139,24 +161,28 @@ def build_training_sample(sample, target_seq_length,
         sentinel_tokens: unique value to be substituted for every replaced span
     """
 
+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
     assert target_seq_length <= max_seq_length
 
     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]
 
-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
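+        # For example, with masked_lm_prob = 0.15 and max_num_tokens = 512,
+        # we keep floor(512 / 1.15) = 445 input tokens.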
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]
 
     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
@@ -166,6 +192,11 @@ def build_training_sample(sample, target_seq_length,
 
     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
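+    # A mean span length below 1 is treated as a fraction of the sequence length.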
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
@@ -183,22 +214,64 @@ def build_training_sample(sample, target_seq_length,
         sampling_style=sampling_style, prefix_lm=prefix_lm,
     )
 
-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
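+        # Number of target tokens; used below for the loss and attention masks.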
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
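+        # Drop the decoder input's leading BOS; it is re-added at the front of
+        # the merged sequence below.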
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
+
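+        # Compute the loss only on the target tokens at the end of the sequence.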
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
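+        # Causal (history) attention mask over the full concatenated sequence.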
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
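+            # Non-causal decoder: allow bidirectional attention within the prefix.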
+            dec_mask[:-num_labels, :-num_labels] = 1
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
     return train_sample