Commit f802317

Add UL2 pretraining for T5 model

1 parent 006c4e9 commit f802317

4 files changed: +392 -1 lines changed

megatron/arguments.py

Lines changed: 34 additions & 0 deletions
@@ -49,6 +49,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_autoresume_args(parser)
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
+    parser = _add_ul2_args(parser)
     parser = _add_logging_args(parser)
     parser = _add_zero_args(parser)
     parser = _add_memoryopt_args(parser)
@@ -1024,6 +1025,39 @@ def _add_vit_args(parser):
     return parser


+def _add_ul2_args(parser):
+    group = parser.add_argument_group(title="UL2")
+
+    group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
+                       default=None,
+                       help='Probability of each denoising objective to be '
+                       'selected. Uniform distribution by default.')
+    group.add_argument('--ul2-denoisers', nargs='+', type=str,
+                       default=['R', 'R', 'S', 'X', 'X', 'X', 'X'],
+                       choices=['R', 'S', 'X'],
+                       help='What type of UL2 denoising objective the other '
+                       'UL2 configurations refer to.')
+    group.add_argument('--ul2-mean-span-lengths', nargs='+', type=float,
+                       default=[3, 8, 0.25, 3, 8, 64, 64],
+                       help='Mean length for sampling span lengths. '
+                       'Numbers < 1 indicate a mean length of the sequence '
+                       'length times that number.')
+    group.add_argument('--ul2-mask-ratios', nargs='+', type=float,
+                       default=[0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5],
+                       help='Ratio of masked tokens in the full sequence.')
+    group.add_argument('--ul2-r-denoiser-token', type=str, default='[R]',
+                       help='What token to prepend for the UL2 R-denoising '
+                       'objective.')
+    group.add_argument('--ul2-s-denoiser-token', type=str, default='[S]',
+                       help='What token to prepend for the UL2 S-denoising '
+                       'objective.')
+    group.add_argument('--ul2-x-denoiser-token', type=str, default='[X]',
+                       help='What token to prepend for the UL2 X-denoising '
+                       'objective.')
+
+    return parser
+
+
 def _add_zero_args(parser):
     """Text generate arguments."""

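Taken together, the seven default entries define UL2's mixture of denoisers: position i of --ul2-denoisers, --ul2-mean-span-lengths, and --ul2-mask-ratios configures one objective (in UL2 terms, R for regular span corruption, S for sequential prefix-LM denoising, X for extreme denoising), so the list-valued flags must all have the same length. A minimal sketch, not part of the commit, that pairs the defaults to show the resulting mixture:

    # Defaults copied from _add_ul2_args; each column configures one denoiser.
    denoisers = ['R', 'R', 'S', 'X', 'X', 'X', 'X']
    mean_span_lengths = [3, 8, 0.25, 3, 8, 64, 64]
    mask_ratios = [0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5]

    for d, span, ratio in zip(denoisers, mean_span_lengths, mask_ratios):
        span_desc = (f'{span} tokens' if span >= 1
                     else f'{span:.0%} of the sequence length')
        print(f'{d}-denoiser: mean span {span_desc}, {ratio:.0%} masked')
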
megatron/data/dataset_utils.py

Lines changed: 20 additions & 1 deletion
@@ -38,8 +38,9 @@
 DSET_TYPE_BERT = 'standard_bert'
 DSET_TYPE_ICT = 'ict'
 DSET_TYPE_T5 = 't5'
+DSET_TYPE_UL2 = 'ul2'
 
-DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
+DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_UL2]
 
 
 class SamplingStyle(Enum):
@@ -553,6 +554,7 @@ def build_dataset(index, name):
         from megatron.data.bert_dataset import BertDataset
         from megatron.data.ict_dataset import ICTDataset
         from megatron.data.t5_dataset import T5Dataset
+        from megatron.data.ul2_dataset import UL2Dataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.
@@ -591,6 +593,23 @@ def build_dataset(index, name):
                     short_seq_prob=short_seq_prob,
                     **kwargs
                 )
+            elif dataset_type == DSET_TYPE_UL2:
+                args = get_args()
+                dataset = UL2Dataset(
+                    indexed_dataset=indexed_dataset,
+                    denoiser_ratios=args.ul2_denoiser_ratios,
+                    denoisers=args.ul2_denoisers,
+                    mean_span_lengths=args.ul2_mean_span_lengths,
+                    mask_ratios=args.ul2_mask_ratios,
+                    denoiser_tokens={
+                        'R': args.ul2_r_denoiser_token,
+                        'S': args.ul2_s_denoiser_token,
+                        'X': args.ul2_x_denoiser_token,
+                    },
+                    max_seq_length_dec=max_seq_length_dec,
+                    short_seq_prob=short_seq_prob,
+                    **kwargs,
+                )
             elif dataset_type == DSET_TYPE_BERT:
                 dataset = BertDataset(
                     indexed_dataset=indexed_dataset,

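The new branch mirrors the existing T5 branch, with everything except the UL2-specific knobs arriving via **kwargs. As a usage sketch, a pretraining script would presumably request this dataset type roughly as follows (the actual entry point is in the fourth changed file, which is not shown on this page; keyword names follow the existing build_train_valid_test_datasets signature in dataset_utils.py, and all values and paths are illustrative):

    from megatron.data.dataset_utils import build_train_valid_test_datasets

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=['my-corpus_text_sentence'],  # illustrative prefix
        data_impl='mmap',
        splits_string='949,50,1',
        train_valid_test_num_samples=[10000, 100, 10],
        max_seq_length=512,
        masked_lm_prob=0.15,  # ignored by UL2: UL2Dataset passes None upstream
        short_seq_prob=0.1,
        seed=1234,
        skip_warmup=True,
        max_seq_length_dec=256,
        dataset_type='ul2',   # routes to the new DSET_TYPE_UL2 branch
    )
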
megatron/data/ul2_dataset.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""UL2-style dataset."""
+
+import numpy as np
+
+from megatron import get_tokenizer
+from megatron.data.dataset_utils import (
+    create_masked_lm_predictions,
+    get_samples_mapping,
+    SamplingStyle
+)
+from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+
+
+class UL2Dataset(T5Dataset):
+
+    def __init__(self, name, indexed_dataset, data_prefix,
+                 num_epochs, max_num_samples, denoiser_ratios,
+                 denoisers, mean_span_lengths, mask_ratios,
+                 denoiser_tokens, max_seq_length, max_seq_length_dec,
+                 short_seq_prob, seed):
+
+        if denoiser_ratios is None:
+            # Uniform distribution by default.
+            denoiser_ratios = [1 / len(denoisers)] * len(denoisers)
+
+        assert (
+            len(denoiser_ratios) == len(denoisers)
+            == len(mean_span_lengths) == len(mask_ratios)
+        ), (
+            'some UL2 configurations do not correspond to the number of '
+            'denoising objectives'
+        )
+
+        super().__init__(name, indexed_dataset, data_prefix,
+                         num_epochs, max_num_samples, None,
+                         max_seq_length, max_seq_length_dec,
+                         short_seq_prob, seed)
+
+        # Params to store.
+        self.denoiser_ratios = [
+            denoiser_ratio / sum(denoiser_ratios)
+            for denoiser_ratio in denoiser_ratios
+        ]
+        self.denoisers = [denoiser.upper() for denoiser in denoisers]
+        self.mean_span_lengths = mean_span_lengths
+        self.mask_ratios = mask_ratios
+
+        # Vocab stuff.
+        tokenizer = get_tokenizer()
+        # Remove CLS token because we don't need it.
+        del self.cls_id
+        self.cls_ids = {
+            denoiser: tokenizer.vocab[token]
+            for (denoiser, token) in denoiser_tokens.items()
+        }
+        # cls_token = self.vocab_id_to_token_dict[tokenizer.cls]
+        # if cls_token not in self.cls_ids:
+        #     self.cls_ids[cls_token] = tokenizer.cls
+
+        # Filter out denoiser tokens.
+        self.sentinel_tokens = [
+            token
+            for token in tokenizer.additional_special_tokens_ids
+            if token not in self.cls_ids.values()
+        ]
+        assert len(self.sentinel_tokens) > 0, \
+            "Provide the argument --vocab-extra-ids 100 to the script"
+
+    def __getitem__(self, idx):
+
+        start_index, end_index, seq_length = self.samples_mapping[idx]
+        sample = []
+        for index in range(start_index, end_index):
+            sample.append(self.indexed_dataset[index])
+        # Note that this rng state should be numpy and not python since
+        # python randint is inclusive whereas the numpy one is exclusive.
+        np_rng = np.random.RandomState(seed=(self.seed + idx))
+        return build_training_sample(sample, seq_length,
+                                     self.max_seq_length,  # needed for padding
+                                     self.max_seq_length_dec,
+                                     self.vocab_id_list,
+                                     self.vocab_id_to_token_dict,
+                                     self.cls_ids, self.sep_id,
+                                     self.mask_id, self.pad_id,
+                                     self.denoiser_ratios, self.denoisers,
+                                     self.mean_span_lengths, self.mask_ratios,
+                                     np_rng,
+                                     self.bos_id, self.eos_id,
+                                     self.sentinel_tokens)
+
+
+def build_training_sample(sample, target_seq_length,
+                          max_seq_length, max_seq_length_dec,
+                          vocab_id_list, vocab_id_to_token_dict,
+                          cls_ids, sep_id, mask_id, pad_id,
+                          denoiser_ratios, denoisers,
+                          mean_span_lengths, mask_ratios,
+                          np_rng, bos_id=None,
+                          eos_id=None, sentinel_tokens=None):
+    """Build training sample.
+
+    Arguments:
+        sample: A list of sentences in which each sentence is a list of
+            token ids.
+        target_seq_length: Desired sequence length.
+        max_seq_length: Maximum length of the sequence. All values are
+            padded to this length.
+        max_seq_length_dec: Maximum length of the decoder sequence.
+        vocab_id_list: List of vocabulary ids. Used to pick a random id.
+        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
+        cls_ids: Start of example ids.
+        sep_id: Separator id.
+        mask_id: Mask token id.
+        pad_id: Padding token id.
+        denoiser_ratios: Probability of each denoising objective to be
+            selected.
+        denoisers: What type of UL2 denoising objective the other UL2
+            configurations refer to.
+        mean_span_lengths: Mean length for sampling span lengths.
+            Numbers < 1 indicate a mean length of the sequence length times
+            that number.
+        mask_ratios: Ratio of masked tokens in the full sequence.
+        np_rng: Random number generator. Note that this rng state should be
+            numpy and not python since python randint is inclusive for the
+            upper bound whereas the numpy one is exclusive.
+        bos_id: Start of decoder example id.
+        eos_id: End of generation id.
+        sentinel_tokens: Unique value to be substituted for every replaced
+            span.
+    """
+
+    assert target_seq_length <= max_seq_length
+
+    # Flatten sentences into one list.
+    tokens = [token for sentence in sample for token in sentence]
+
+    # Truncate to `target_seq_length`.
+    max_num_tokens = target_seq_length
+    truncated = len(tokens) > max_num_tokens
+    tokens = tokens[:max_num_tokens]
+
+    # Denoiser selection.
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)),
+                                   p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
+    # Prepend objective token.
+    cls_id = cls_ids.get(denoiser)
+    if cls_id is None:
+        raise ValueError('unknown denoiser')
+    tokens = [cls_id] + tokens
+
+    # Masking.
+    max_predictions_per_seq = masked_lm_prob * len(tokens)
+    if denoiser == 'R' or denoiser == 'X':
+        sampling_style = SamplingStyle.NORMAL
+        prefix_lm = False
+    elif denoiser == 'S':
+        sampling_style = SamplingStyle.UNIFORM
+        prefix_lm = True
+    else:
+        raise ValueError('unknown denoiser')
+    (
+        tokens, masked_positions, masked_labels, _, masked_spans,
+    ) = create_masked_lm_predictions(
+        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
+        max_ngrams=max_ngrams, masking_style="t5",
+        sampling_style=sampling_style, prefix_lm=prefix_lm,
+    )
+
+    # Padding.
+    tokens_enc, tokens_dec_in, labels, enc_mask, \
+        dec_mask, enc_dec_mask, loss_mask \
+        = pad_and_convert_to_numpy(tokens, masked_positions,
+                                   masked_labels, pad_id, max_seq_length,
+                                   max_seq_length_dec, masked_spans,
+                                   bos_id, eos_id, sentinel_tokens)
+
+    train_sample = {
+        'text_enc': tokens_enc,
+        'text_dec': tokens_dec_in,
+        'labels': labels,
+        'loss_mask': loss_mask,
+        'truncated': int(truncated),
+        'enc_mask': enc_mask,
+        'dec_mask': dec_mask,
+        'enc_dec_mask': enc_dec_mask,
+    }
+    return train_sample

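To make the objective sampling concrete, here is a small self-contained sketch, not part of the commit, that mirrors the denoiser-selection logic of build_training_sample using the commit's default settings on a 512-token sample:

    import numpy as np

    # Defaults from _add_ul2_args; ratios become uniform in
    # UL2Dataset.__init__ when --ul2-denoiser-ratios is left unset.
    denoisers = ['R', 'R', 'S', 'X', 'X', 'X', 'X']
    mean_span_lengths = [3, 8, 0.25, 3, 8, 64, 64]
    mask_ratios = [0.15, 0.15, 0.25, 0.5, 0.5, 0.15, 0.5]
    denoiser_ratios = [1 / len(denoisers)] * len(denoisers)

    np_rng = np.random.RandomState(seed=1234)
    tokens = list(range(512))  # stand-in for a truncated 512-token sample

    denoiser_index = np_rng.choice(np.arange(len(denoisers)),
                                   p=denoiser_ratios)
    denoiser = denoisers[denoiser_index]
    masked_lm_prob = mask_ratios[denoiser_index]
    mean_ngrams = mean_span_lengths[denoiser_index]
    if mean_ngrams < 1:
        # Fractional means are relative to sequence length: the S-denoiser's
        # 0.25 becomes round(512 * 0.25) = 128 tokens here.
        mean_ngrams = round(len(tokens) * mean_ngrams)
    max_ngrams = mean_ngrams * 2 - 1  # cap spans symmetrically around the mean

    print(denoiser, masked_lm_prob, mean_ngrams, max_ngrams)

A draw of the S-denoiser then switches masking to SamplingStyle.UNIFORM with prefix_lm=True (a prefix-LM objective), while R and X draws keep SamplingStyle.NORMAL span corruption, as in the Masking step of build_training_sample above.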