# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UL2-style dataset."""

import numpy as np

from megatron import get_tokenizer
from megatron.data.dataset_utils import (
    create_masked_lm_predictions,
    get_samples_mapping,
    SamplingStyle
)
from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset


class UL2Dataset(T5Dataset):
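    """Dataset for UL2-style ("mixture of denoisers") pretraining.

    Builds on `T5Dataset`, but instead of a single masked-LM objective it
    randomly selects one of several denoisers per sample, each with its own
    mean span length and mask ratio, and prepends the corresponding denoiser
    token to the encoder input.
    """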

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, denoiser_ratios,
                 denoisers, mean_span_lengths, mask_ratios,
                 denoiser_tokens, max_seq_length, max_seq_length_dec,
                 short_seq_prob, seed):

        if denoiser_ratios is None:
            # Uniform distribution by default.
            denoiser_ratios = [1 / len(denoisers)] * len(denoisers)

        assert (
            len(denoiser_ratios) == len(denoisers)
            == len(mean_span_lengths) == len(mask_ratios)
        ), (
            'denoiser_ratios, mean_span_lengths, and mask_ratios must each '
            'have one entry per denoising objective in denoisers'
        )

        # UL2 masks with per-denoiser `mask_ratios`, so the single
        # `masked_lm_prob` stored by the parent class is not used here;
        # pass a placeholder.
        super().__init__(name, indexed_dataset, data_prefix,
                         num_epochs, max_num_samples, 0,
                         max_seq_length, max_seq_length_dec,
                         short_seq_prob, seed)

        # Params to store.
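        # Normalize the ratios so they form a probability distribution for
        # sampling a denoiser per example.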
        self.denoiser_ratios = [
            denoiser_ratio / sum(denoiser_ratios)
            for denoiser_ratio in denoiser_ratios
        ]
        self.denoisers = [denoiser.upper() for denoiser in denoisers]
        self.mean_span_lengths = mean_span_lengths
        self.mask_ratios = mask_ratios

        # Vocab stuff.
        tokenizer = get_tokenizer()
        # Remove CLS token because we don't need it.
        del self.cls_id
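        # Instead, each denoiser gets its own CLS-style token that is
        # prepended to its examples (see `build_training_sample`).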
        self.cls_ids = {
            denoiser: tokenizer.vocab[token]
            for (denoiser, token) in denoiser_tokens.items()
        }
        # cls_token = self.vocab_id_to_token_dict[tokenizer.cls]
        # if cls_token not in self.cls_ids:
        #     self.cls_ids[cls_token] = tokenizer.cls

        # Filter out denoiser tokens.
        self.sentinel_tokens = [
            token
            for token in tokenizer.additional_special_tokens_ids
            if token not in self.cls_ids.values()
        ]
        assert len(self.sentinel_tokens) > 0, \
            "Provide the argument --vocab-extra-ids 100 to the script"

    def __getitem__(self, idx):

        start_index, end_index, seq_length = self.samples_mapping[idx]
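        # Collect the sentences that make up this sample.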
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.max_seq_length_dec,
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_ids, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.denoiser_ratios, self.denoisers,
                                     self.mean_span_lengths, self.mask_ratios,
                                     np_rng,
                                     self.bos_id, self.eos_id,
                                     self.sentinel_tokens)


def build_training_sample(sample, target_seq_length,
                          max_seq_length, max_seq_length_dec,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_ids, sep_id, mask_id, pad_id,
                          denoiser_ratios, denoisers,
                          mean_span_lengths, mask_ratios,
                          np_rng, bos_id=None,
                          eos_id=None, sentinel_tokens=None):
    """Build training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of
            token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded
            to this length.
        max_seq_length_dec: Maximum length of the decoder sequence.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_ids: Start of example ids, one per denoiser.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        denoiser_ratios: Probability of each denoising objective to be
            selected.
        denoisers: The type of UL2 denoising objective ('R', 'S', or 'X')
            that each entry of the other per-denoiser configurations
            refers to.
        mean_span_lengths: Mean length for sampling span lengths. Numbers < 1
            indicate a mean length of the sequence length times that number.
        mask_ratios: Ratio of masked tokens in the full sequence.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
        bos_id: start of decoder example id.
        eos_id: end of generation id.
        sentinel_tokens: unique value to be substituted for every replaced
            span.
    """

    assert target_seq_length <= max_seq_length

    # Flatten the sentences into one list of tokens.
    tokens = [token for sentence in sample for token in sentence]

    # Truncate to `target_seq_length`.
    max_num_tokens = target_seq_length
    truncated = len(tokens) > max_num_tokens
    tokens = tokens[:max_num_tokens]

    # Denoiser selection.
    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
    denoiser = denoisers[denoiser_index]
    masked_lm_prob = mask_ratios[denoiser_index]
    mean_ngrams = mean_span_lengths[denoiser_index]
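    # A fractional mean span length is interpreted relative to the current
    # sequence length.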
    if mean_ngrams < 1:
        mean_ngrams = round(len(tokens) * mean_ngrams)
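    # Span lengths are drawn from [1, max_ngrams], so this maximum keeps the
    # sampled span lengths centered on roughly `mean_ngrams`.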
    max_ngrams = mean_ngrams * 2 - 1

    # Prepend objective token.
    cls_id = cls_ids.get(denoiser)
    if cls_id is None:
        raise ValueError('unknown denoiser')
    tokens = [cls_id] + tokens

    # Masking.
    max_predictions_per_seq = masked_lm_prob * len(tokens)
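    # R- and X-denoisers use T5-style span corruption; the S-denoiser is the
    # sequential (prefix-LM) objective.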
    if denoiser in ('R', 'X'):
        sampling_style = SamplingStyle.NORMAL
        prefix_lm = False
    elif denoiser == 'S':
        sampling_style = SamplingStyle.UNIFORM
        prefix_lm = True
    else:
        raise ValueError('unknown denoiser')
    (
        tokens, masked_positions, masked_labels, _, masked_spans,
    ) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
        max_ngrams=max_ngrams, masking_style="t5",
        sampling_style=sampling_style, prefix_lm=prefix_lm,
    )

    # Padding.
    tokens_enc, tokens_dec_in, labels, enc_mask, \
    dec_mask, enc_dec_mask, loss_mask \
        = pad_and_convert_to_numpy(tokens, masked_positions,
                                   masked_labels, pad_id, max_seq_length,
                                   max_seq_length_dec, masked_spans,
                                   bos_id, eos_id, sentinel_tokens)

    train_sample = {
        'text_enc': tokens_enc,
        'text_dec': tokens_dec_in,
        'labels': labels,
        'loss_mask': loss_mask,
        'truncated': int(truncated),
        'enc_mask': enc_mask,
        'dec_mask': dec_mask,
        'enc_dec_mask': enc_dec_mask,
    }
    return train_sample