Skip to content

Commit 6b14e04

Browse files
committed
training: ensure model_save_dir exists
1 parent ccf9dd0 commit 6b14e04

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tibert/train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,8 @@ def train_coref_model(
241241

242242
# Model saving
243243
# ------------
244-
if conll_f1 > best_f1 or best_f1 == 0:
245-
best_model = copy.deepcopy(model).to("cpu")
246-
best_f1 = conll_f1
247-
if not model_save_dir is None:
248-
model.save_pretrained(os.path.join(model_save_dir, "model"))
249244
if not model_save_dir is None:
245+
os.makedirs(model_save_dir, exist_ok=True)
250246
_save_train_checkpoint(
251247
os.path.join(model_save_dir, "checkpoint.pth"),
252248
model,
@@ -255,5 +251,10 @@ def train_coref_model(
255251
bert_lr,
256252
task_lr,
257253
)
254+
if conll_f1 > best_f1 or best_f1 == 0:
255+
best_model = copy.deepcopy(model).to("cpu")
256+
best_f1 = conll_f1
257+
if not model_save_dir is None:
258+
model.save_pretrained(os.path.join(model_save_dir, "model"))
258259

259260
return best_model

0 commit comments

Comments
 (0)