This repository was archived by the owner on Jan 15, 2024. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -53,9 +53,7 @@ def main(args):
5353 random .shuffle (fnames )
5454 num_files = len (fnames )
5555 num_out_files = min (args .num_out_files , num_files )
56- file_volume = math .ceil (num_files / num_out_files )
57- splited_files = np .array_split (fnames , file_volume )
58- num_out_files = len (splited_files )
56+ splited_files = np .array_split (fnames , num_out_files )
5957 output_files = [os .path .join (
6058 args .output , "owt-pretrain-record-{}.npz" .format (str (i ).zfill (4 ))) for i in range (num_out_files )]
6159 print ("All preprocessed features will be saved in {} npz files" .format (num_out_files ))
Original file line number Diff line number Diff line change @@ -472,9 +472,11 @@ def train(args):
472472 train_end_time - train_start_time ))
473473 if writer is not None :
474474 writer .close ()
475- model_name = args .model_name .replace ('google' , 'gluon' )
476- save_dir = os .path .join (args .output_dir , model_name )
477- final_save (model , save_dir , tokenizer )
475+
476+ if local_rank == 0 :
477+ model_name = args .model_name .replace ('google' , 'gluon' )
478+ save_dir = os .path .join (args .output_dir , model_name )
479+ final_save (model , save_dir , tokenizer )
478480
479481# TODO(zheyuye), Directly implement a metric for weighted accuracy
480482
You can’t perform that action at this time.
0 commit comments