# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

 | 16 | +"""UL2-style dataset."""  | 
 | 17 | + | 
 | 18 | +import numpy as np  | 
 | 19 | + | 
 | 20 | +from megatron import get_tokenizer  | 
 | 21 | +from megatron.data.dataset_utils import (  | 
 | 22 | +    create_masked_lm_predictions,  | 
 | 23 | +    get_samples_mapping,  | 
 | 24 | +    SamplingStyle  | 
 | 25 | +)  | 
 | 26 | +from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +class UL2Dataset(T5Dataset):  | 
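    """UL2-style mixture-of-denoisers dataset.

    Extends `T5Dataset`, but for every sample one of the configured
    denoising objectives ('R', 'S' or 'X') is selected according to
    `denoiser_ratios`; that objective's mean span length and mask ratio are
    applied and its token from `denoiser_tokens` is prepended to the
    encoder input.

    Purely illustrative example of how the per-denoiser arguments line up
    (the token strings and numbers below are hypothetical, not defaults of
    this module):

        denoisers         = ['R', 'S', 'X']
        denoiser_tokens   = {'R': '[R]', 'S': '[S]', 'X': '[X]'}
        mean_span_lengths = [3, 0.25, 32]    # < 1 means a fraction of the sequence
        mask_ratios       = [0.15, 0.25, 0.5]
        denoiser_ratios   = None             # None -> uniform over the denoisers
    """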

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, denoiser_ratios,
                 denoisers, mean_span_lengths, mask_ratios,
                 denoiser_tokens, max_seq_length, max_seq_length_dec,
                 short_seq_prob, seed):

        if denoiser_ratios is None:
            # Uniform
            denoiser_ratios = [1 / len(denoisers)] * len(denoisers)

        assert (
            len(denoiser_ratios) == len(denoisers)
            == len(mean_span_lengths) == len(mask_ratios)
        ), (
            'denoiser_ratios, denoisers, mean_span_lengths, and mask_ratios '
            'must all contain one entry per denoising objective'
        )

        # Params to store.
        self.name = name
        self.seed = seed
        self.denoiser_ratios = [
            denoiser_ratio / sum(denoiser_ratios)
            for denoiser_ratio in denoiser_ratios
        ]
        self.denoisers = [denoiser.upper() for denoiser in denoisers]
        self.mean_span_lengths = mean_span_lengths
        self.mask_ratios = mask_ratios
        self.max_seq_length = max_seq_length
        self.max_seq_length_dec = max_seq_length_dec

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 2,  # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   False)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_ids = {
            denoiser: tokenizer.vocab[token]
            for (denoiser, token) in denoiser_tokens.items()
        }
        # cls_token = self.vocab_id_to_token_dict[tokenizer.cls]
        # if cls_token not in self.cls_ids:
        #     self.cls_ids[cls_token] = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        # Filter out denoiser tokens.
        self.sentinel_tokens = [
            token
            for token in tokenizer.additional_special_tokens_ids
            if token not in self.cls_ids.values()
        ]
        assert len(self.sentinel_tokens) > 0, \
            "Provide the argument --vocab-extra-ids 100 to the script"

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):

        start_index, end_index, seq_length = self.samples_mapping[idx]
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.max_seq_length_dec,
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_ids, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.denoiser_ratios, self.denoisers,
                                     self.mean_span_lengths, self.mask_ratios,
                                     np_rng,
                                     self.bos_id, self.eos_id,
                                     self.sentinel_tokens)


def build_training_sample(sample, target_seq_length,
                          max_seq_length, max_seq_length_dec,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_ids, sep_id, mask_id, pad_id,
                          denoiser_ratios, denoisers,
                          mean_span_lengths, mask_ratios,
                          np_rng, bos_id=None,
                          eos_id=None, sentinel_tokens=None):
    """Build training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        max_seq_length_dec: Maximum length of the decoder sequence. Decoder
            values are padded to this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_ids: Mapping from denoiser to its start-of-example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        denoiser_ratios: Probability of each denoising objective being selected.
        denoisers: The type of UL2 denoising objective that each of the other
              per-denoiser configurations refers to.
        mean_span_lengths: Mean length for sampling span lengths. Numbers < 1
              indicate a mean length of the sequence length times that number.
        mask_ratios: Ratio of masked tokens in the full sequence.
        np_rng: Random number generator. Note that this rng state should be
              numpy and not python since python randint is inclusive for
              the upper bound whereas the numpy one is exclusive.
        bos_id: Start of decoder example id.
        eos_id: End of generation id.
        sentinel_tokens: Unique ids substituted for every replaced span.
    """

    assert target_seq_length <= max_seq_length

    # Flatten sentences into one list.
    tokens = [token for sentence in sample for token in sentence]

    # Truncate to `target_seq_length`.
    max_num_tokens = target_seq_length
    truncated = len(tokens) > max_num_tokens
    tokens = tokens[:max_num_tokens]

    # Denoiser selection.
    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
    denoiser = denoisers[denoiser_index]
    masked_lm_prob = mask_ratios[denoiser_index]
    mean_ngrams = mean_span_lengths[denoiser_index]
    if mean_ngrams < 1:
        mean_ngrams = round(len(tokens) * mean_ngrams)
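    # Note: the cap below presumably keeps the allowed span lengths in
    # [1, 2 * mean - 1], symmetric around the mean, so the average sampled
    # span length stays close to `mean_ngrams`.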
    max_ngrams = mean_ngrams * 2 - 1

    # Prepend objective token.
    cls_id = cls_ids.get(denoiser)
    if cls_id is None:
        raise ValueError('unknown denoiser')
    tokens = [cls_id] + tokens

    # Masking.
    max_predictions_per_seq = masked_lm_prob * len(tokens)
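    # 'R' (regular) and 'X' (extreme) denoisers use span corruption with
    # normally sampled span lengths (SamplingStyle.NORMAL); the 'S'
    # (sequential) denoiser uses uniform span sampling in prefix-LM mode,
    # i.e. UL2's sequential denoising objective.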
    if denoiser in ('R', 'X'):
        sampling_style = SamplingStyle.NORMAL
        prefix_lm = False
    elif denoiser == 'S':
        sampling_style = SamplingStyle.UNIFORM
        prefix_lm = True
    else:
        raise ValueError('unknown denoiser')
    (
        tokens, masked_positions, masked_labels, _, masked_spans,
    ) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
        max_ngrams=max_ngrams, masking_style="t5",
        sampling_style=sampling_style, prefix_lm=prefix_lm,
    )

    # Padding.
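    # `pad_and_convert_to_numpy` is shared with the T5 dataset: it inserts
    # the sentinel tokens for the masked spans, builds the decoder input and
    # labels, pads everything to the maximum lengths and returns the
    # corresponding attention and loss masks.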
    tokens_enc, tokens_dec_in, labels, enc_mask, \
    dec_mask, enc_dec_mask, loss_mask \
        = pad_and_convert_to_numpy(tokens, masked_positions,
                                   masked_labels, pad_id, max_seq_length,
                                   max_seq_length_dec, masked_spans,
                                   bos_id, eos_id, sentinel_tokens)

    train_sample = {
        'text_enc': tokens_enc,
        'text_dec': tokens_dec_in,
        'labels': labels,
        'loss_mask': loss_mask,
        'truncated': int(truncated),
        'enc_mask': enc_mask,
        'dec_mask': dec_mask,
        'enc_dec_mask': enc_dec_mask,
    }
    return train_sample