Mlm adaptation #287
This pull request adds a new file (367 lines): a T5-style span-corruption ("MLM") dataset built on top of the existing GPT dataset machinery.

```python
"""Non-Causal Mask Language Model Finetune Style dataset."""

import os
import time

import numpy as np
import torch

from megatron import print_rank_0, get_tokenizer
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
from megatron.data.gpt_dataset import GPTDataset


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    sequence_length,
                                    noise_density,
                                    mean_noise_span_length,
                                    seed,
                                    skip_warmup
                                    ):
    assert noise_density is not None
    assert mean_noise_span_length is not None

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix=data_prefix[0],
            data_impl=data_impl,
            splits_string=splits_string,
            train_valid_test_num_samples=train_valid_test_num_samples,
            sequence_length=sequence_length,
            noise_density=noise_density,
            mean_noise_span_length=mean_noise_span_length,
            seed=seed,
            skip_warmup=skip_warmup
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            data_prefix=prefixes[i],
            data_impl=data_impl,
            splits_string=splits_string,
            train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
            sequence_length=sequence_length,
            noise_density=noise_density,
            mean_noise_span_length=mean_noise_span_length,
            seed=seed,
            skip_warmup=skip_warmup
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
```
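For orientation, `data_prefix` is the usual Megatron `--data-path` list: either a single dataset prefix, or alternating weight/prefix pairs that trigger the blending branch above. The corpus names below are hypothetical, and the parsing shown is only an assumption about what `get_datasets_weights_and_num_samples` does with the list (the real helper also turns the weights into per-dataset sample counts):

```python
# Illustrative only; hypothetical dataset prefixes, not files from this repo.
single_prefix = ["corpus_a_text_document"]            # -> single-dataset branch
blended_prefix = ["0.3", "corpus_a_text_document",    # -> blending branch:
                  "0.7", "corpus_b_text_document"]    #    alternating weight, prefix pairs

# Assumed parsing performed by the upstream helper (sketch, not the repo's code).
pairs = list(zip(blended_prefix[0::2], blended_prefix[1::2]))
weights = [float(w) for w, _ in pairs]
prefixes = [p for _, p in pairs]
print(prefixes)  # ['corpus_a_text_document', 'corpus_b_text_document']
print(weights)   # [0.3, 0.7]
```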
```python
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     sequence_length,
                                     noise_density,
                                     mean_noise_span_length,
                                     seed,
                                     skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Build the dataset accordingly.
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
            dataset = MLMDataset(
                indexed_dataset=indexed_dataset,
                documents=documents,
                noise_density=noise_density,
                mean_noise_span_length=mean_noise_span_length,
                name=name,
                data_prefix=data_prefix,
                sequence_length=sequence_length,
                num_samples=train_valid_test_num_samples[index],
                seed=seed,
            )
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)
```
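The document-level split comes from `get_train_valid_test_split_`, which is not part of this diff. As a rough mental model (an assumption following the usual Megatron convention of a comma-separated ratio string), a `splits_string` such as `"949,50,1"` carves the document index range into contiguous train/valid/test slices:

```python
# Minimal sketch of the assumed splits convention; not the repo's implementation.
def toy_train_valid_test_split(splits_string, size):
    weights = [float(s) for s in splits_string.split(",")]
    weights = [w / sum(weights) for w in weights]
    bounds = [0]
    for w in weights:
        bounds.append(bounds[-1] + int(round(w * size)))
    bounds[-1] = size  # make sure the last boundary covers every document
    return bounds

print(toy_train_valid_test_split("949,50,1", 10_000))  # should print [0, 9490, 9990, 10000]
```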
```python
class MLMDataset(torch.utils.data.Dataset):

    def __init__(
        self,
        name,
        indexed_dataset,
        documents,
        data_prefix,
        sequence_length,
        num_samples,
        seed,
        noise_density=0.15,
        mean_noise_span_length=3
    ):

        # Params to store.
        self.name = name
        self.seed = seed
        self.sequence_length = sequence_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
        # To ensure that the input length is `sequence_length`, we need to increase the maximum length
        # according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
        number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
            sequence_length=self.sequence_length,
            noise_density=self.noise_density,
            mean_noise_span_length=self.mean_noise_span_length
        )
        self.number_of_raw_tokens = number_of_raw_tokens
        self.inputs_length = inputs_length
        self.targets_length = targets_length
        self.num_noise_spans = num_noise_spans

        # Build the samples mapping.
        self._gpt_dataset = GPTDataset(
            name=self.name,
            data_prefix=data_prefix,
            documents=documents,
            indexed_dataset=self.indexed_dataset,
            num_samples=num_samples,
            seq_length=number_of_raw_tokens,
            seed=seed
        )

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.sep_id = tokenizer.sep
        self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
        assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
        assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"

    def __len__(self):
        return len(self._gpt_dataset)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            raise NotImplementedError

        sample = self._gpt_dataset[idx]["text"]

        return build_training_sample(
            sample=sample,
            inputs_length=self.inputs_length,
            targets_length=self.targets_length,
            num_noise_spans=self.num_noise_spans,
            sep_id=self.sep_id,
            all_sentinel_token_ids=self.sentinel_token_ids,
        )


def build_training_sample(
    sample,
    inputs_length,
    targets_length,
    num_noise_spans,
    sep_id,
    all_sentinel_token_ids,
):
    """Build training sample.

    Arguments:
        TODO: Add description
    """

    spans_start, mask_indices = random_spans_noise_mask(
        inputs_length=inputs_length,
        targets_length=targets_length,
        num_noise_spans=num_noise_spans,
    )
    spans_end = np.concatenate([
        spans_start[1:], np.full((1,), len(sample), dtype=np.int32)]
    )

    sentinel_token_ids = all_sentinel_token_ids[:num_noise_spans]
```
Review thread on this line:

> Given […] I wonder if it wouldn't be better to make […].

> I also have a strong intuition that we should want to change those values. But the idea is to have T5 MLM here and rely on their numbers.

The function continues:
```python
    input_token_ids = np.concatenate(
        [
            elt
            for start, end, sentinel_token in zip(spans_start[::2], spans_end[::2], sentinel_token_ids)
            for elt in [sample[start: end], np.full((1,), sentinel_token, dtype=np.int32)]
        ] +
        [np.full((1,), sep_id, dtype=np.int32)]
    )
    target_token_ids = np.concatenate(
        [
            elt
            for start, end, sentinel_token in zip(spans_start[1::2], spans_end[1::2], sentinel_token_ids)
            for elt in [np.full((1,), sentinel_token, dtype=np.int32), sample[start: end]]
        ] +
        [np.full((1,), sep_id, dtype=np.int32)]
    )

    return {
        'input_tokens': input_token_ids,
        'target_tokens': target_token_ids
    }
```
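To make the packing above concrete, here is a self-contained toy run of the same interleaving logic with hand-picked (non-random) spans. All token ids, sentinel ids, and the SEP id are made up for illustration:

```python
import numpy as np

# Toy run of the packing done in build_training_sample.
sample = np.arange(10, dtype=np.int32)                   # raw tokens 0..9
spans_start = np.array([0, 3, 5, 8], dtype=np.int32)     # non-noise, noise, non-noise, noise
spans_end = np.concatenate([spans_start[1:], np.full((1,), len(sample), dtype=np.int32)])
sentinel_token_ids = [100, 101]                          # hypothetical sentinel ids
sep_id = 50                                              # hypothetical SEP id

input_token_ids = np.concatenate(
    [elt
     for start, end, sentinel in zip(spans_start[::2], spans_end[::2], sentinel_token_ids)
     for elt in [sample[start:end], np.full((1,), sentinel, dtype=np.int32)]]
    + [np.full((1,), sep_id, dtype=np.int32)]
)
target_token_ids = np.concatenate(
    [elt
     for start, end, sentinel in zip(spans_start[1::2], spans_end[1::2], sentinel_token_ids)
     for elt in [np.full((1,), sentinel, dtype=np.int32), sample[start:end]]]
    + [np.full((1,), sep_id, dtype=np.int32)]
)

print(input_token_ids.tolist())   # [0, 1, 2, 100, 5, 6, 7, 101, 50]
print(target_token_ids.tolist())  # [100, 3, 4, 101, 8, 9, 50]
```

In the real dataset the span boundaries come from `random_spans_noise_mask`, which is sized so that the two outputs land exactly on `inputs_length` and `targets_length`.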
```python
def compute_input_and_target_lengths(sequence_length, noise_density, mean_noise_span_length):
    """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .

    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have SEP appended and includes that in the reported length.
    Args:
        sequence_length: an integer - desired total length (inputs plus targets) after masking
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        inputs_length: an integer - length in tokens of the encoded inputs sequence
        targets_length: an integer - length in tokens of the encoded targets sequence
        num_noise_spans: an integer - number of noise spans
    """

    def _tokens_length_to_inputs_length_targets_length(_tokens_length):
        num_noise_tokens = int(round(_tokens_length * noise_density))
        num_nonnoise_tokens = _tokens_length - num_noise_tokens
        _num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one SEP token.
        _input_length = num_nonnoise_tokens + _num_noise_spans + 1
        _output_length = num_noise_tokens + _num_noise_spans + 1
        return _input_length, _output_length, _num_noise_spans

    tokens_length = sequence_length
    inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)
    while inputs_length + targets_length > sequence_length:
        tokens_length -= 1
        inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # tokens_length is the number of raw tokens we need to get
    # inputs_length will be the input
    # targets_length will be the target
    # num_noise_spans is the number of spans we have to replace
    return tokens_length, inputs_length, targets_length, num_noise_spans
```
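As a sanity check on the arithmetic, the helper can be mirrored outside Megatron. For the class defaults (`noise_density=0.15`, `mean_noise_span_length=3`) and an illustrative 512-token budget (the actual `sequence_length` comes from the training script), it should settle on roughly 464 raw tokens, a 418-token input and a 94-token target with 23 noise spans:

```python
# Standalone re-derivation of compute_input_and_target_lengths; same arithmetic, no Megatron imports.
def lengths(sequence_length, noise_density=0.15, mean_noise_span_length=3):
    def split(tokens_length):
        num_noise = int(round(tokens_length * noise_density))
        num_spans = int(round(num_noise / mean_noise_span_length))
        # non-noise tokens + one sentinel per span + SEP, and noise tokens + sentinels + SEP
        return (tokens_length - num_noise + num_spans + 1, num_noise + num_spans + 1, num_spans)

    tokens_length = sequence_length
    inputs_length, targets_length, num_spans = split(tokens_length)
    while inputs_length + targets_length > sequence_length:
        tokens_length -= 1
        inputs_length, targets_length, num_spans = split(tokens_length)
    return tokens_length, inputs_length, targets_length, num_spans

print(lengths(512))  # expected to be about (464, 418, 94, 23)
```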
```python
def random_spans_noise_mask(
    inputs_length,
    targets_length,
    num_noise_spans,
):
    """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

    Noise mask consisting of random spans of noise tokens.
    The number of noise tokens and the number of noise spans and non-noise spans
    are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
    Spans alternate between non-noise and noise, beginning with non-noise.
    Subject to the above restrictions, all masks are equally likely.
    Args:
        inputs_length: int32 scalar
        targets_length: int32 scalar
        num_noise_spans: int32 scalar
    Returns:
        an integer tensor with shape [2 * num_noise_spans] containing the span start indices
        a boolean tensor with shape [num_noise_tokens + num_nonnoise_tokens], one entry per raw token
    """
    # pick the lengths of the noise spans and the non-noise spans
    num_noise_tokens = targets_length - num_noise_spans - 1
    num_nonnoise_tokens = inputs_length - num_noise_spans - 1
    number_of_raw_tokens = num_noise_tokens + num_nonnoise_tokens

    def _random_segmentation(num_items, num_segments):
        """Partition a sequence of items randomly into non-empty segments.
        Args:
            num_items: an integer scalar > 0
            num_segments: an integer scalar in [1, num_items]
        Returns:
            a Tensor with shape [num_segments] containing positive integers that add
            up to num_items
        """
        mask_indices = np.arange(num_items - 1) < (num_segments - 1)
        # TODO @thomasw21 handle random state correctly, ie synchronized across TP.
```
Review thread on this line:

> This scares me a bit because TP random-state things are hard to debug, but tbh we should just test ASAP to see if the loss goes down at the expected rate.

> Ah yes, I need to double check that. I can have a go at it. I had forgotten about this TODO.

The function continues:
```python
        # we might not care as get_batch_pipe broadcasts data to all devices.
        np.random.shuffle(mask_indices)
        first_in_segment = np.pad(mask_indices, [[1, 0]], constant_values=0)
        segment_id = np.cumsum(first_in_segment)
        # count length of sub segments assuming that list is sorted
        _, segment_length = np.unique(segment_id, return_counts=True)
        return segment_length

    noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
    nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

    interleaved_span_lengths = np.reshape(
        np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
    )
    span_starts = np.concatenate([np.full((1,), 0, dtype=np.int32), np.cumsum(interleaved_span_lengths)[:-1]])
    span_start_indicator = np.zeros((number_of_raw_tokens,), dtype=np.int8)
    span_start_indicator[span_starts] = True
    span_num = np.cumsum(span_start_indicator)
    is_noise = np.equal(span_num % 2, 1)

    return span_starts, is_noise
```
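Finally, a small self-contained demo of the segmentation trick used above: shuffling a fixed number of "start a new segment here" flags partitions `num_items` into `num_segments` non-empty parts, and interleaving the non-noise and noise partitions yields the span starts that `build_training_sample` slices with `spans_start[::2]` (kept text) and `spans_start[1::2]` (masked text). The numbers are arbitrary and the seed is for demo reproducibility only:

```python
import numpy as np

rng = np.random.RandomState(0)  # demo-only seeding; the dataset code uses the global NumPy state

def random_segmentation(num_items, num_segments):
    """Partition num_items into num_segments non-empty parts (same trick as _random_segmentation)."""
    mask_indices = np.arange(num_items - 1) < (num_segments - 1)
    rng.shuffle(mask_indices)
    first_in_segment = np.pad(mask_indices, [[1, 0]], constant_values=0)
    segment_id = np.cumsum(first_in_segment)
    _, segment_length = np.unique(segment_id, return_counts=True)
    return segment_length

noise_span_lengths = random_segmentation(num_items=6, num_segments=3)
nonnoise_span_lengths = random_segmentation(num_items=12, num_segments=3)
print(noise_span_lengths, noise_span_lengths.sum())        # three positive lengths summing to 6
print(nonnoise_span_lengths, nonnoise_span_lengths.sum())  # three positive lengths summing to 12

# Interleave lengths (non-noise first) and turn them into span start positions.
interleaved = np.reshape(np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [-1])
span_starts = np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(interleaved)[:-1]])
print(span_starts)  # six span start positions inside the 18 raw tokens
```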