@@ -369,7 +369,7 @@ def rank_files_write(args, dset, idx, encoder):
369369 try :
370370 # create data file for each rank
371371 if args .rank == 0 :
372- msg (f"Vocab size: { args .vocab_size } " )
372+ msg (f"Vocab size: { args .padded_vocab_size } " )
373373 msg (f"Output prefix: { args .output_prefix } " )
374374 output_bin_files = {}
375375 output_idx_files = {}
@@ -378,7 +378,7 @@ def rank_files_write(args, dset, idx, encoder):
378378 filebase = get_filename (args , key , args .rank )
379379 output_bin_files [key ] = data_file_path (filebase )
380380 output_idx_files [key ] = index_file_path (filebase )
381- best_dtype = best_fitting_dtype (args .vocab_size ) if args .dataset_impl == "mmap" else None
381+ best_dtype = best_fitting_dtype (args .padded_vocab_size ) if args .dataset_impl == "mmap" else None
382382 builders [key ] = make_builder (output_bin_files [key ],
383383 impl = args .dataset_impl ,
384384 dtype = best_dtype )
@@ -515,7 +515,7 @@ def rank_files_merge_serial(args):
515515 filebase = get_filename (args , key )
516516 output_bin_files [key ] = data_file_path (filebase )
517517 output_idx_files [key ] = index_file_path (filebase )
518- best_dtype = best_fitting_dtype (args .vocab_size ) if args .dataset_impl == "mmap" else None
518+ best_dtype = best_fitting_dtype (args .padded_vocab_size ) if args .dataset_impl == "mmap" else None
519519 builders [key ] = make_builder (output_bin_files [key ],
520520 impl = args .dataset_impl ,
521521 dtype = best_dtype )
@@ -600,7 +600,6 @@ def main():
600600 nltk .download ("punkt" , quiet = True )
601601
602602 encoder = Encoder (args )
603- args .vocab_size = encoder .tokenizer .vocab_size
604603
605604 # wait for all ranks before stopping timer
606605 args .distctx .barrier ()
0 commit comments