Mlm adaptation #287
@@ -0,0 +1,372 @@
"""Non-Causal Mask Language Model Finetune Style dataset."""

import numpy as np
import torch

from megatron import print_rank_0, get_tokenizer
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
from megatron.data.gpt_dataset import GPTDataset


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    sequence_length,
                                    noise_density,
                                    mean_noise_span_length,
                                    seed,
                                    skip_warmup
                                    ):
    assert noise_density is not None
    assert mean_noise_span_length is not None

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix=data_prefix[0],
            data_impl=data_impl,
            splits_string=splits_string,
            train_valid_test_num_samples=train_valid_test_num_samples,
            sequence_length=sequence_length,
            noise_density=noise_density,
            mean_noise_span_length=mean_noise_span_length,
            seed=seed,
            skip_warmup=skip_warmup
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            data_prefix=prefixes[i],
            data_impl=data_impl,
            splits_string=splits_string,
            train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
            sequence_length=sequence_length,
            noise_density=noise_density,
            mean_noise_span_length=mean_noise_span_length,
            seed=seed,
            skip_warmup=skip_warmup
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     sequence_length,
                                     noise_density,
                                     mean_noise_span_length,
                                     seed,
                                     skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0(' sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Build the dataset accordingly.
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
            dataset = MLMDataset(
                indexed_dataset=indexed_dataset,
                documents=documents,
                noise_density=noise_density,
                mean_noise_span_length=mean_noise_span_length,
                name=name,
                data_prefix=data_prefix,
                sequence_length=sequence_length,
                num_samples=train_valid_test_num_samples[index],
                seed=seed,
            )
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


class MLMDataset(torch.utils.data.Dataset):

    def __init__(
        self,
        name,
        indexed_dataset,
        documents,
        data_prefix,
        sequence_length,
        num_samples,
        seed,
        noise_density=0.15,
        mean_noise_span_length=3
    ):

        # Params to store.
        self.name = name
        self.seed = seed
        self.sequence_length = sequence_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        # T5-like span masked language modeling will fuse consecutively masked tokens into a single sentinel token.
        # To ensure that the input length is `sequence_length`, we need to increase the maximum length
        # according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
        number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
            # The +1 accounts for the extra token that autoregressive systems require us to add.
            sequence_length=self.sequence_length + 1,
            noise_density=self.noise_density,
            mean_noise_span_length=self.mean_noise_span_length
        )
        self.number_of_raw_tokens = number_of_raw_tokens
        self.inputs_length = inputs_length
        self.targets_length = targets_length
        self.num_noise_spans = num_noise_spans

        # Build the samples mapping.
        self._gpt_dataset = GPTDataset(
            name=self.name,
            data_prefix=data_prefix,
            documents=documents,
            indexed_dataset=self.indexed_dataset,
            num_samples=num_samples,
            seq_length=number_of_raw_tokens,
            seed=seed
        )

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.sep_id = tokenizer.sep
        self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
        assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
        assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"

    def __len__(self):
        # The underlying GPT dataset determines the number of samples.
        return len(self._gpt_dataset)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            raise NotImplementedError

        sample = self._gpt_dataset[idx]["text"]

        return build_training_sample(
            sample=sample,
            inputs_length=self.inputs_length,
            targets_length=self.targets_length,
            num_noise_spans=self.num_noise_spans,
            sep_id=self.sep_id,
            all_sentinel_token_ids=self.sentinel_token_ids,
        )


def build_training_sample(
    sample,
    inputs_length,
    targets_length,
    num_noise_spans,
    sep_id,
    all_sentinel_token_ids,
):
| """Build training sample. | ||
|
|
||
| Arguments: | ||
| sample: int32 tensor | ||
| inputs_length: integer | ||
| targets_length: integer | ||
| num_noise_spans: integer | ||
| sep_id: integer | ||
| all_sentinel_token_ids: List[int] | ||
| Returns: | ||
| Dict with following keys: | ||
| - `input_tokens`: int32 tensor with as length input_length, | ||
| - `target_tokens`: int32 tensor with as length targets_length + 1, | ||
| """ | ||

    spans_start, mask_indices = random_spans_noise_mask(
        inputs_length=inputs_length,
        targets_length=targets_length,
        num_noise_spans=num_noise_spans,
    )
    spans_end = np.concatenate([
        spans_start[1:], np.full((1,), len(sample), dtype=np.int32)]
    )

    sentinel_token_ids = all_sentinel_token_ids[:num_noise_spans]

[Review thread]
Comment: Given […] I wonder if it wouldn't be better to make […]
Reply: I also have a strong intuition that we should want to change those values. But the idea is to have T5 MLM here and rely on their numbers.
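As context for the thread above: the defaults kept here are the T5 ones (noise_density=0.15, mean_noise_span_length=3). A quick sanity check of what they imply for a given sequence budget could look like the sketch below; the import path is an assumption for illustration, not part of this PR.

# Hypothetical check of the T5 defaults discussed above; the module path is assumed,
# compute_input_and_target_lengths is the helper defined later in this file.
from megatron.data.mlm_dataset import compute_input_and_target_lengths

tokens_length, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
    sequence_length=513,           # e.g. a 512-token sample, +1 as in MLMDataset.__init__
    noise_density=0.15,            # T5 default kept by this PR
    mean_noise_span_length=3,      # T5 default kept by this PR
)
# The while loop in the helper guarantees inputs_length + targets_length <= 513,
# and num_noise_spans must stay within the number of --vocab-extra-ids sentinels.
print(tokens_length, inputs_length, targets_length, num_noise_spans)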
    input_token_ids = np.concatenate(
        [
            elt
            for start, end, sentinel_token in zip(spans_start[::2], spans_end[::2], sentinel_token_ids)
            for elt in [sample[start: end], np.full((1,), sentinel_token, dtype=np.int32)]
        ] +
        [np.full((1,), sep_id, dtype=np.int32)]
    )
    target_token_ids = np.concatenate(
        [
            elt
            for start, end, sentinel_token in zip(spans_start[1::2], spans_end[1::2], sentinel_token_ids)
            for elt in [np.full((1,), sentinel_token, dtype=np.int32), sample[start: end]]
        ] +
        [np.full((1,), sep_id, dtype=np.int32)]
    )

    return {
        'input_tokens': input_token_ids,
        'target_tokens': target_token_ids
    }

def compute_input_and_target_lengths(sequence_length, noise_density, mean_noise_span_length):
    """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .

    It computes training parameters that avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding.
    The number of noise tokens and the number of noise and non-noise spans
    are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have SEP appended and includes that in the reported length.

    Args:
        sequence_length: an integer - the total token budget for inputs plus targets
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: an integer - length of the original text in tokens
        inputs_length: an integer - length in tokens of the encoded inputs sequence
        targets_length: an integer - length in tokens of the encoded targets sequence
        num_noise_spans: an integer - number of noise spans to replace
    """

    def _tokens_length_to_inputs_length_targets_length(_tokens_length):
        num_noise_tokens = int(round(_tokens_length * noise_density))
        num_nonnoise_tokens = _tokens_length - num_noise_tokens
        _num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans and one SEP token.
        _input_length = num_nonnoise_tokens + _num_noise_spans + 1
        _output_length = num_noise_tokens + _num_noise_spans + 1
        return _input_length, _output_length, _num_noise_spans

    tokens_length = sequence_length
    inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)
    while inputs_length + targets_length > sequence_length:
        tokens_length -= 1
        inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # tokens_length is the number of raw tokens we need to get
    # inputs_length will be the input
    # targets_length will be the target
    # num_noise_spans is the number of spans we have to replace
    return tokens_length, inputs_length, targets_length, num_noise_spans


def random_spans_noise_mask(
    inputs_length,
    targets_length,
    num_noise_spans,
):

| """This function is inspired from `random_spans_noise_mask <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ . | ||
| Noise mask consisting of random spans of noise tokens. | ||
| Spans alternate between non-noise and noise, beginning with non-noise. | ||
| Args: | ||
| inputs_length: int32 scalar | ||
| targets_length: int32 scalar | ||
| num_noise_spans: int32 scalar | ||
| Returns: | ||
| a int8 tensor with shape [num_noise_spans] | ||
| a boolean tensor with shape [length] | ||
| """ | ||
| # # pick the lengths of the noise spans and the non-noise spans | ||
| num_noise_tokens = targets_length - num_noise_spans - 1 | ||
| num_nonnoise_tokens = inputs_length - num_noise_spans - 1 | ||
| number_of_raw_tokens = num_noise_tokens + num_nonnoise_tokens | ||

    def _random_segmentation(num_items, num_segments):
        """Partition a sequence of items randomly into non-empty segments.
        Args:
            num_items: an integer scalar > 0
            num_segments: an integer scalar in [1, num_items]
        Returns:
            a Tensor with shape [num_segments] containing positive integers that add
            up to num_items
        """
        mask_indices = np.arange(num_items - 1) < (num_segments - 1)
        # TODO @thomasw21 handle random state correctly, ie synchronized across TP.
[Review thread]
Comment: This scares me a bit because TP random-state things are hard to debug, but to be honest we should just test ASAP to see if the loss goes down at the expected rate.
Reply: Ah yes, I need to double-check that. I can have a go at it; I had forgotten about this TODO.
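One possible way to address the TODO discussed above, purely as a sketch and not part of this PR: seed a local numpy RandomState from the dataset seed and the sample index and use it for the shuffle, so every rank that materializes the same sample draws the same mask. The helper name below is hypothetical.

# Sketch only, not in this PR: an explicit RNG threaded through the segmentation,
# so the noise mask is deterministic per (seed, sample index) across TP ranks.
import numpy as np

def _random_segmentation_seeded(num_items, num_segments, rng):
    # Same partitioning as _random_segmentation in this file, with the global
    # np.random replaced by the provided RandomState.
    mask_indices = np.arange(num_items - 1) < (num_segments - 1)
    rng.shuffle(mask_indices)
    first_in_segment = np.pad(mask_indices, [[1, 0]], constant_values=0)
    segment_id = np.cumsum(first_in_segment)
    _, segment_length = np.unique(segment_id, return_counts=True)
    return segment_length

# In MLMDataset.__getitem__ one could build rng = np.random.RandomState((seed + idx) % 2**32)
# and pass it down through random_spans_noise_mask.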
        # we might not care as get_batch_pipe broadcasts data to all devices.
        np.random.shuffle(mask_indices)
        first_in_segment = np.pad(mask_indices, [[1, 0]], constant_values=0)
        segment_id = np.cumsum(first_in_segment)
        # count length of sub segments assuming that list is sorted
        _, segment_length = np.unique(segment_id, return_counts=True)
        return segment_length

    noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
    nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

    interleaved_span_lengths = np.reshape(
        np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
    )
    span_starts = np.concatenate([np.full((1,), 0, dtype=np.int32), np.cumsum(interleaved_span_lengths)[:-1]])
    span_start_indicator = np.zeros((number_of_raw_tokens,), dtype=np.int8)
    span_start_indicator[span_starts] = True
    span_num = np.cumsum(span_start_indicator)
    is_noise = np.equal(span_num % 2, 1)

    return span_starts, is_noise
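To see the pieces above end to end, a toy run of the two helpers could look like the sketch below. The import path, SEP id, and sentinel ids are made up for illustration; only the output lengths are deterministic, since the mask itself is random.

import numpy as np
# Hypothetical import path; the helpers are the ones defined in this file.
from megatron.data.mlm_dataset import compute_input_and_target_lengths, build_training_sample

raw_len, in_len, tgt_len, n_spans = compute_input_and_target_lengths(
    sequence_length=129, noise_density=0.15, mean_noise_span_length=3)

sample = np.arange(raw_len, dtype=np.int32)          # toy "token ids"
sep_id = 10_000                                      # hypothetical SEP id
sentinels = list(range(32_000, 32_000 + n_spans))    # hypothetical sentinel ids

out = build_training_sample(
    sample=sample,
    inputs_length=in_len,
    targets_length=tgt_len,
    num_noise_spans=n_spans,
    sep_id=sep_id,
    all_sentinel_token_ids=sentinels,
)
# input_tokens keeps the non-noise text with one sentinel per masked span plus SEP;
# target_tokens holds each sentinel followed by the tokens it masked, plus SEP.
print(len(out["input_tokens"]), len(out["target_tokens"]))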