Skip to content

Commit f5d8bd1

Browse files
committed
TrainerConfig.save_checkpoint_upon_crash
1 parent 949af4c commit f5d8bd1

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

alf/algorithms/config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ def __init__(self,
4545
sync_progress_to_envs=False,
4646
num_checkpoints=10,
4747
confirm_checkpoint_upon_crash=True,
48+
save_checkpoint_upon_crash=False,
4849
no_thread_env_for_conf=False,
4950
evaluate=False,
5051
num_evals=None,
@@ -207,6 +208,8 @@ def __init__(self,
207208
num_checkpoints (int): how many checkpoints to save for the training
208209
confirm_checkpoint_upon_crash (bool): whether to prompt for whether
209210
do checkpointing after crash.
211+
save_checkpoint_upon_crash (bool): whether to do checkpointing after
212+
crash.
210213
no_thread_env_for_conf (bool): not to create an unwrapped env for
211214
the purpose of showing operative configurations. If True, no
212215
``ThreadEnvironment`` will ever be created, regardless of the
@@ -401,6 +404,7 @@ def __init__(self,
401404
self.sync_progress_to_envs = sync_progress_to_envs
402405
self.num_checkpoints = num_checkpoints
403406
self.confirm_checkpoint_upon_crash = confirm_checkpoint_upon_crash
407+
self.save_checkpoint_upon_crash = save_checkpoint_upon_crash
404408
self.no_thread_env_for_conf = no_thread_env_for_conf
405409
self.evaluate = evaluate
406410
self.num_evals = num_evals

alf/trainers/policy_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -377,8 +377,11 @@ def train(self):
377377
self._save_checkpoint()
378378
checkpoint_saved = True
379379
finally:
380-
if (self._config.confirm_checkpoint_upon_crash
380+
if (self._config.save_checkpoint_upon_crash
381381
and not checkpoint_saved and self._rank <= 0):
382+
self._save_checkpoint()
383+
elif (self._config.confirm_checkpoint_upon_crash
384+
and not checkpoint_saved and self._rank <= 0):
382385
# Prompts for checkpoint only when running single process
383386
# training (rank is -1) or master process of DDP training (rank
384387
# is 0).

0 commit comments

Comments
 (0)