Skip to content

Commit 99c8fe0

Browse files
authored
Use padded vocab size in preprocessing scripts (#253)
1 parent b289eb8 commit 99c8fe0

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tools/preprocess_data_dist.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def rank_files_write(args, dset, idx, encoder):
369369
try:
370370
# create data file for each rank
371371
if args.rank == 0:
372-
msg(f"Vocab size: {args.vocab_size}")
372+
msg(f"Vocab size: {args.padded_vocab_size}")
373373
msg(f"Output prefix: {args.output_prefix}")
374374
output_bin_files = {}
375375
output_idx_files = {}
@@ -378,7 +378,7 @@ def rank_files_write(args, dset, idx, encoder):
378378
filebase = get_filename(args, key, args.rank)
379379
output_bin_files[key] = data_file_path(filebase)
380380
output_idx_files[key] = index_file_path(filebase)
381-
best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
381+
best_dtype = best_fitting_dtype(args.padded_vocab_size) if args.dataset_impl == "mmap" else None
382382
builders[key] = make_builder(output_bin_files[key],
383383
impl=args.dataset_impl,
384384
dtype=best_dtype)
@@ -515,7 +515,7 @@ def rank_files_merge_serial(args):
515515
filebase = get_filename(args, key)
516516
output_bin_files[key] = data_file_path(filebase)
517517
output_idx_files[key] = index_file_path(filebase)
518-
best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
518+
best_dtype = best_fitting_dtype(args.padded_vocab_size) if args.dataset_impl == "mmap" else None
519519
builders[key] = make_builder(output_bin_files[key],
520520
impl=args.dataset_impl,
521521
dtype=best_dtype)
@@ -600,7 +600,6 @@ def main():
600600
nltk.download("punkt", quiet=True)
601601

602602
encoder = Encoder(args)
603-
args.vocab_size = encoder.tokenizer.vocab_size
604603

605604
# wait for all ranks before stopping timer
606605
args.distctx.barrier()

tools/preprocess_data_many_cores.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def process_samples(simple_queue, process_id, args, level, writer: Connection):
117117
output_filename = get_output_filename(args.output_prefix, key, level, process_id)
118118
output_bin_files[key] = data_file_path(output_filename)
119119
output_idx_files[key] = index_file_path(output_filename)
120-
best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
120+
best_dtype = best_fitting_dtype(args.padded_vocab_size) if args.dataset_impl == "mmap" else None
121121
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
122122
impl=args.dataset_impl,
123123
dtype=best_dtype)
@@ -329,7 +329,7 @@ def main():
329329
output_filename = f"{args.output_prefix}_{key}_{level}"
330330
output_bin_files[key] = data_file_path(output_filename)
331331
output_idx_files[key] = index_file_path(output_filename)
332-
best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
332+
best_dtype = best_fitting_dtype(args.padded_vocab_size) if args.dataset_impl == "mmap" else None
333333
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
334334
impl=args.dataset_impl,
335335
dtype=best_dtype)

0 commit comments

Comments
 (0)