Commit 58d9204

SaulLu and thomasw21 authored
add pad-vocab-size-to argument and tests (#255)
* init new test
* test pad vocab size to
* add logs
* log to warning
* change TP
* fix loop
* revert
* remove hack size
* this new test should pass
* test not divisible by num tp
* Revert "remove hack size" This reverts commit bcc6d8d.
* Revert "Revert "remove hack size"" This reverts commit 8322f89.
* Revert "test not divisible by num tp" This reverts commit 92614bf.
* Revert "this new test should pass" This reverts commit 9e17a4f.
* change info to warning
* change to print
* add print
* test 2
* new print
* woups
* more
* woups
* comment
* raise errors
* woups
* pad to save vocab size
* simplify test
* assert test raised
* print error msg
* check msg error
* check error
* woups
* clean
* simplify
* remove unused print
* add comment
* add test multiple of tp size
* add print
* add check
* clean
* Update megatron/mpu/layers.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/tokenizer/tokenizer.py Co-authored-by: Thomas Wang <[email protected]>
* chnage micro-batch-size
* use tiny vocab
* fix data dir
* fix arg
* change micro-batch-size
* adept input ids
* assertIn
* change micro batch size
* Fix test TP Co-authored-by: Thomas Wang <[email protected]>
* unused var
* add test make_vocab_size_divisible_by
* fix test_tokenizer_vocab_size_multiple_of_tp_size test
* Fix padded vocab size on preprocessing scripts (#257)
* Add tokenizer options in preprocessing scripts
* This should fix the TP issue? Co-authored-by: SaulLu <[email protected]>
* documentation

Co-authored-by: Thomas Wang <[email protected]>
1 parent 00f5f88 commit 58d9204

File tree

7 files changed: +161 -17 lines

megatron/arguments.py

Lines changed: 4 additions & 0 deletions
@@ -369,6 +369,10 @@ def _add_network_size_args(parser):
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
                        'This is added for computational efficieny reasons.')
+    group.add_argument('--pad-vocab-size-to', type=int, default=None,
+                       help='Pad the vocab size to this value.'
+                       'This value must be greater than the initial size of the tokenizer'
+                       ', needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
     group.add_argument('--apply-residual-connection-post-layernorm',

megatron/mpu/layers.py

Lines changed: 5 additions & 0 deletions
@@ -217,6 +217,9 @@ def __init__(self, num_embeddings, embedding_dim,


     def forward(self, input_):
+        if torch.any(input_ >= self.num_embeddings):
+            raise ValueError(f"There is an input id in the input that is greater than the highest possible input id.\nInput: {input_}\nnum_embeddings: {self.num_embeddings}")
+
         if self.tensor_model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
@@ -225,7 +228,9 @@ def forward(self, input_):
             masked_input = input_.clone() - self.vocab_start_index
             masked_input[input_mask] = 0
         else:
+            # input_ is garanted to be in the range [0:self.vocab_end_index - self.vocab_start_index] thanks to the first check
             masked_input = input_
+
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
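
For orientation, here is a minimal standalone sketch of the guard this hunk adds to VocabParallelEmbedding.forward. The helper name check_input_ids and the vocab size 5120 are illustrative only, not part of the commit; the real layer performs the same check before building its tensor-parallel mask.

# Illustrative sketch of the new out-of-range check, outside the real class.
import torch

def check_input_ids(input_, num_embeddings):
    # Any id >= num_embeddings is rejected up front, mirroring the added guard.
    if torch.any(input_ >= num_embeddings):
        raise ValueError(
            f"There is an input id in the input that is greater than the highest possible input id.\n"
            f"Input: {input_}\nnum_embeddings: {num_embeddings}"
        )

check_input_ids(torch.tensor([[5119, 0, 1, 5100]]), 5120)   # 5119 is the largest valid id for a 5120-entry table
# check_input_ids(torch.tensor([[5120, 0, 1, 2]]), 5120)    # would raise ValueError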

megatron/tokenizer/tokenizer.py

Lines changed: 19 additions & 8 deletions
@@ -68,14 +68,25 @@ def build_tokenizer(args):


 def _vocab_size_with_padding(orig_vocab_size, args):
-    """Pad vocab size so it is divisible by model parallel size and
-    still having GPU friendly size."""
-
-    after = orig_vocab_size
-    multiple = args.make_vocab_size_divisible_by * \
-        args.tensor_model_parallel_size
-    while (after % multiple) != 0:
-        after += 1
+    """Apply the requested rules to change the size of the vocabulary"""
+    if args.pad_vocab_size_to is not None:
+        if args.pad_vocab_size_to < orig_vocab_size:
+            raise ValueError(
+                f"You asked to pad the vocabulary to {args.pad_vocab_size_to} when the initial vocabulary size is "
+                f"{orig_vocab_size}. You can only pad to a higher value."
+            )
+
+        if args.make_vocab_size_divisible_by is not None and (args.pad_vocab_size_to % args.make_vocab_size_divisible_by) != 0:
+            raise ValueError(f"{args.pad_vocab_size_to} is not divisible by {args.make_vocab_size_divisible_by}")
+
+        after = args.pad_vocab_size_to
+    else:
+        # Pad vocab size so it is divisible by model parallel size and still having GPU friendly size.
+        after = orig_vocab_size
+        multiple = args.make_vocab_size_divisible_by * \
+            args.tensor_model_parallel_size
+        while (after % multiple) != 0:
+            after += 1
     if args.rank == 0:
         print(' > padded vocab (size: {}) with {} dummy tokens '
               '(new size: {})'.format(
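
To make the two code paths above concrete, here is a hedged, standalone re-implementation used only for this worked example (argument names follow the diff; the function itself is not part of the commit). With GPT-2's 50,257-token vocabulary, the default path rounds up to a multiple of make-vocab-size-divisible-by times the TP size, so the padded size depends on the TP degree, while --pad-vocab-size-to fixes it exactly.

# Standalone sketch of the padding rules in _vocab_size_with_padding (illustrative).
def padded_vocab_size(orig_vocab_size, pad_vocab_size_to=None,
                      make_vocab_size_divisible_by=128,
                      tensor_model_parallel_size=1):
    if pad_vocab_size_to is not None:
        # Explicit target: must be at least the tokenizer size and divisible
        # by make_vocab_size_divisible_by, otherwise an error is raised.
        if pad_vocab_size_to < orig_vocab_size:
            raise ValueError("can only pad to a value above the tokenizer size")
        if pad_vocab_size_to % make_vocab_size_divisible_by != 0:
            raise ValueError(f"{pad_vocab_size_to} is not divisible by "
                             f"{make_vocab_size_divisible_by}")
        return pad_vocab_size_to
    # Default: round up to a multiple of divisor * tensor-parallel size.
    after = orig_vocab_size
    multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
    while after % multiple != 0:
        after += 1
    return after

print(padded_vocab_size(50257, tensor_model_parallel_size=1))   # 50304
print(padded_vocab_size(50257, tensor_model_parallel_size=2))   # 50432
print(padded_vocab_size(50257, pad_vocab_size_to=51200))        # 51200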

tests/test_tensor_parallel.py

Lines changed: 108 additions & 6 deletions
@@ -25,6 +25,7 @@
 class MegDSTestTP(TestCasePlus):
     def get_default_args(self):
         """return a dictionary with key as argument name and value as additional arguments"""
+        data_dir = f"{self.data_dir}/gpt2"
         return {
             # GPT_ARGS
             "--num-layers": "2",
@@ -39,8 +40,9 @@ def get_default_args(self):
             "--lr": "0.00015",
             "--min-lr": "1.0e-5",
             "--train-iters": "5000",
-            "--tokenizer-type": "PretrainedFromHF",
-            "--tokenizer-name-or-path": "gpt2",
+            "--tokenizer-type": "GPT2BPETokenizer",
+            "--merge-file": f"{data_dir}/gpt2-tiny-merges.txt",
+            "--vocab-file": f"{data_dir}/gpt2-tiny-vocab.json",
             "--data-impl": "mmap",
             "--split": "949,50,1",
             "--distributed-backend": "nccl",
@@ -111,8 +113,6 @@ def create_model_inputs(tokens):
         initialize_megatron()
         args = get_args()

-        args.vocab_size = args.padded_vocab_size = 1024
-
         tokenizer = get_tokenizer()

         model, _, _ = setup_model_and_optimizer(gpt_model_provider)
@@ -141,7 +141,6 @@ def create_model_inputs(tokens):
         else:
             token_ids = torch.tensor(token_ids)

-
         model.micro_batches = 1
         model.set_batch_fn(create_model_inputs)
         # process batch
@@ -156,7 +155,7 @@ def create_model_inputs(tokens):

         output = model.eval_batch(iter([token_ids]), compute_loss = False, reduce_output = None)[0]

-        output = gather_from_tensor_model_parallel_region(output)[..., :tokenizer.vocab_size]
+        output = gather_from_tensor_model_parallel_region(output)

         if save != None:
             args.save = save
@@ -169,6 +168,7 @@ def test_alibi_tp(self):
         cp_dir = self.get_auto_remove_tmp_dir()

         command_args = self.get_default_args()
+        command_args["--pad-vocab-size-to"] = "5120" # This is equal to 128 * 40 which is above the len of gp2-tiny vocabulary
         command_args["--position-embedding-type"] = "alibi"
         command_args["--tensor-model-parallel-size"] = "1"

@@ -192,5 +192,107 @@ def test_alibi_tp(self):
         logging.getLogger().critical(output-output2)
         self.assertTrue(np.allclose(output,output2, atol=5e-3, rtol=0), "Different results when running with TP=1 and TP=2")

+
+
+    def test_embedding_matrix_tp(self):
+        mp.set_start_method('spawn', force=True)
+        cp_dir = self.get_auto_remove_tmp_dir()
+
+        command_args = self.get_default_args()
+        command_args["--pad-vocab-size-to"] = "5120" # This is equal to 128 * 40 which is above the len of gp2-tiny vocabulary
+        command_args["--seq-length"] = "4"
+        command_args["--micro-batch-size"] = "2"
+        tokens = [[5119, 0, 1, 5100],[0, 1, 5111, 5101]]
+
+        command_args["--tensor-model-parallel-size"] = "1"
+
+        pool = Pool(1)
+        # tp_index, tp_size, command_args, token_ids, save, load
+        result = pool.map(MegDSTestTP.infer_model, [((0, 1, command_args, tokens, cp_dir, None))])
+        pool.close()
+        pool.join()
+
+        output, _ = result[0]
+        logging.getLogger().info("First done!")
+
+        command_args["--tensor-model-parallel-size"] = "2"
+
+        pool = Pool(2)
+        result = pool.map(MegDSTestTP.infer_model, [((0, 2, command_args, tokens, None, cp_dir)), ((1, 2, command_args, tokens, None, cp_dir))])
+        pool.close()
+        pool.join()
+
+        output2, _ = result[0]
+
+        logging.getLogger().critical(output-output2)
+        self.assertTrue(np.allclose(output,output2, atol=5e-3, rtol=0), "Different results when running with TP=1 and TP=2")
+
+
+    def test_embedding_matrix_tp_with_invalid_tokens_ids(self):
+        mp.set_start_method('spawn', force=True)
+
+        command_args = self.get_default_args()
+        command_args["--pad-vocab-size-to"] = "5120" # This is equal to 128 * 40 which is above the len of gp2-tiny vocabulary
+        command_args["--seq-length"] = "4"
+        command_args["--micro-batch-size"] = "2"
+        tokens = [[5120, 0, 1, 2],[0, 1, 3, 4]]
+
+        command_args["--tensor-model-parallel-size"] = "1"
+
+        pool = Pool(1)
+        with pytest.raises(Exception) as exc_info:
+            _ = pool.map(MegDSTestTP.infer_model, [((0, 1, command_args, tokens, None, None))])
+        pool.close()
+        pool.join()
+
+        self.assertIn("There is an input id in the input that is greater than the highest possible input id" , str(exc_info.value))
+
+        logging.getLogger().info("First done!")
+
+        command_args["--tensor-model-parallel-size"] = "2"
+
+        pool = Pool(2)
+        with pytest.raises(Exception) as exc_info:
+            _ = pool.map(MegDSTestTP.infer_model, [((0, 2, command_args, tokens, None, None)), ((1, 2, command_args, tokens, None, None))])
+        pool.close()
+        pool.join()
+
+        self.assertIn("There is an input id in the input that is greater than the highest possible input id", str(exc_info.value))
+
+
+    def test_tokenizer_vocab_size_multiple_of_tp_size(self):
+        mp.set_start_method('spawn', force=True)
+
+        command_args = self.get_default_args()
+        command_args["--pad-vocab-size-to"] = "5121" # This is equal to 128 * 40 + 1 which is above the len of gp2-tiny vocabulary
+        command_args["--micro-batch-size"] = "4"
+        command_args["--tensor-model-parallel-size"] = "2"
+        command_args["--make-vocab-size-divisible-by"] = "1"
+
+        pool = Pool(2)
+        with pytest.raises(Exception) as exc_info:
+            _ = pool.map(MegDSTestTP.infer_model, [((0, 2, command_args, None, None, None)), ((1, 2, command_args, None, None, None))])
+        pool.close()
+        pool.join()
+
+        self.assertEqual(str(exc_info.value), "5121 is not divisible by 2")
+
+    def test_tokenizer_raise_error_make_vocab_size_divisible_by(self):
+        mp.set_start_method('spawn', force=True)
+
+        command_args = self.get_default_args()
+        command_args["--pad-vocab-size-to"] = "5121" # This is equal to 128 * 40 + 1 which is above the len of gp2-tiny vocabulary
+        command_args["--micro-batch-size"] = "4"
+
+
+        pool = Pool(2)
+        with pytest.raises(Exception) as exc_info:
+            _ = pool.map(MegDSTestTP.infer_model, [((0, 2, command_args, None, None, None)), ((1, 2, command_args, None, None, None))])
+        pool.close()
+        pool.join()
+
+        self.assertEqual(str(exc_info.value), "5121 is not divisible by 128")
+
+
 if __name__ == '__main__':
     unittest.main()

tools/preprocess_data.py

Lines changed: 8 additions & 1 deletion
@@ -119,6 +119,14 @@ def get_args():
                        help='Append an <eod> token to the end of a document.')
     group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                        help="Name or path of the huggingface tokenizer.")
+    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
+                       help='Pad the vocab size to be divisible by this value.'
+                       'This is added for computational efficieny reasons.')
+    group.add_argument('--pad-vocab-size-to', type=int, default=None,
+                       help='Pad the vocab size to be divisible by this value.'
+                       'Value of the size of the vocabulary of the tokenizer to reach. This value must be greater than'
+                       ' the initial size of the tokenizer. If this argument is used the value of '
+                       '`make-vocab-size-divisible-by` will be ignored.')

     group = parser.add_argument_group(title='output data')
     group.add_argument('--output-prefix', type=str, required=True,
@@ -140,7 +148,6 @@ def get_args():

     # some default/dummy values for the tokenizer
     args.rank = 0
-    args.make_vocab_size_divisible_by = 128
     args.tensor_model_parallel_size = 1
     args.vocab_extra_ids = 0


tools/preprocess_data_dist.py

Lines changed: 9 additions & 1 deletion
@@ -167,6 +167,15 @@ def get_args():
                        help='Path to binary output file without suffix')
     group.add_argument('--dataset-impl', type=str, default='mmap',
                        choices=['lazy', 'cached', 'mmap'])
+    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
+                       help='Pad the vocab size to be divisible by this value.'
+                       'This is added for computational efficieny reasons.')
+    group.add_argument('--pad-vocab-size-to', type=int, default=None,
+                       help='Pad the vocab size to be divisible by this value.'
+                       'Value of the size of the vocabulary of the tokenizer to reach. This value must be greater than'
+                       ' the initial size of the tokenizer. If this argument is used the value of '
+                       '`make-vocab-size-divisible-by` will be ignored.')
+

     group = parser.add_argument_group(title='runtime')
     group.add_argument('--torch-backend', type=str, default='gloo', choices=['gloo', 'mpi'],
@@ -198,7 +207,6 @@ def get_args():
     args.numranks = args.distctx.numranks

     # some default/dummy values for the tokenizer
-    args.make_vocab_size_divisible_by = 128
     args.tensor_model_parallel_size = 1
     args.vocab_extra_ids = 0


tools/preprocess_data_many_cores.py

Lines changed: 8 additions & 1 deletion
@@ -185,6 +185,14 @@ def get_args():
                        help='Append an <eod> token to the end of a document.')
     group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                        help="Name or path of the huggingface tokenizer.")
+    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
+                       help='Pad the vocab size to be divisible by this value.'
+                       'This is added for computational efficieny reasons.')
+    group.add_argument('--pad-vocab-size-to', type=int, default=None,
+                       help='Pad the vocab size to be divisible by this value.'
+                       'Value of the size of the vocabulary of the tokenizer to reach. This value must be greater than'
+                       ' the initial size of the tokenizer. If this argument is used the value of '
+                       '`make-vocab-size-divisible-by` will be ignored.')

     group = parser.add_argument_group(title='output data')
     group.add_argument('--output-prefix', type=str, required=True,
@@ -206,7 +214,6 @@ def get_args():

     # some default/dummy values for the tokenizer
     args.rank = 0
-    args.make_vocab_size_divisible_by = 128
     args.tensor_model_parallel_size = 1
     args.vocab_extra_ids = 0

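
All three preprocessing scripts above now parse --make-vocab-size-divisible-by and --pad-vocab-size-to instead of hard-coding make_vocab_size_divisible_by = 128, so the tokenizer they build can be padded the same way as at training time. Below is a rough sketch of the argument namespace they end up handing to build_tokenizer; the attribute names come from the diffs, while the values are only examples.

# Illustrative only: example of the tokenizer-related attributes assembled by the
# preprocessing scripts before calling megatron.tokenizer.build_tokenizer(args).
from types import SimpleNamespace

args = SimpleNamespace(
    rank=0,
    tensor_model_parallel_size=1,
    vocab_extra_ids=0,
    make_vocab_size_divisible_by=128,   # now a CLI flag rather than a hard-coded value
    pad_vocab_size_to=None,             # optional exact target, e.g. 5120 in the new tests
    tokenizer_type="GPT2BPETokenizer",  # example value
    vocab_file="gpt2-vocab.json",       # example value
    merge_file="gpt2-merges.txt",       # example value
)
# tokenizer = build_tokenizer(args)    # would apply _vocab_size_with_padding as shown earlier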
