File tree Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -241,12 +241,8 @@ def train_coref_model(
241241
242242 # Model saving
243243 # ------------
244- if conll_f1 > best_f1 or best_f1 == 0 :
245- best_model = copy .deepcopy (model ).to ("cpu" )
246- best_f1 = conll_f1
247- if not model_save_dir is None :
248- model .save_pretrained (os .path .join (model_save_dir , "model" ))
249244 if not model_save_dir is None :
245+ os .makedirs (model_save_dir , exist_ok = True )
250246 _save_train_checkpoint (
251247 os .path .join (model_save_dir , "checkpoint.pth" ),
252248 model ,
@@ -255,5 +251,10 @@ def train_coref_model(
255251 bert_lr ,
256252 task_lr ,
257253 )
254+ if conll_f1 > best_f1 or best_f1 == 0 :
255+ best_model = copy .deepcopy (model ).to ("cpu" )
256+ best_f1 = conll_f1
257+ if not model_save_dir is None :
258+ model .save_pretrained (os .path .join (model_save_dir , "model" ))
258259
259260 return best_model
You can’t perform that action at this time.
0 commit comments