
Commit 63ce3c7

make sure temp checkpoints are cleaned up on failed tuning
1 parent 3d56296 commit 63ce3c7

2 files changed: +52 −44 lines

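For context, the two functions touched by this commit sit behind the Tuner entry points. The sketch below is not part of the commit; it is a minimal usage example assuming the Lightning 2.x `Tuner` API, with `TinyModel` as a hypothetical stand-in module.

import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner


class TinyModel(pl.LightningModule):
    """Hypothetical toy module with just enough state for both tuners to run."""

    def __init__(self, lr: float = 1e-3, batch_size: int = 8):
        super().__init__()
        self.lr = lr                  # read/updated by lr_find
        self.batch_size = batch_size  # read/updated by scale_batch_size
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        ds = TensorDataset(torch.randn(256, 32), torch.randn(256, 1))
        return DataLoader(ds, batch_size=self.batch_size)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    tuner = Tuner(trainer)
    # Both calls save a temporary checkpoint before searching; the changes in
    # this commit ensure it is restored and removed even if the search raises.
    tuner.scale_batch_size(TinyModel(), mode="power")
    tuner.lr_find(TinyModel())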

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 16 additions & 13 deletions
@@ -76,24 +76,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()

-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
-    elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)

-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        elif mode == "binsearch":
+            new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)

-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()

-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as e:
+        raise e
+    finally:
+        __scale_batch_restore_params(trainer, params)

-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()

-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)

     return new_size
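
The restore-and-remove calls now run in a finally block, so the temporary checkpoint is deleted and the saved parameters come back even when the scaling run raises partway through. A minimal, self-contained sketch of the same pattern, with hypothetical names and no Lightning dependency:

import json
import os
import tempfile


def tune_with_cleanup(state: dict, tune_step) -> dict:
    """Snapshot state to a temp checkpoint, run a step that may raise,
    and always restore the snapshot and delete the temp file."""
    fd, ckpt_path = tempfile.mkstemp(suffix=".ckpt")
    os.close(fd)
    with open(ckpt_path, "w") as f:
        json.dump(state, f)  # stand-in for trainer.save_checkpoint(ckpt_path)

    try:
        tune_step(state)  # may raise, e.g. while probing larger batch sizes
    finally:
        # Runs on success and on failure alike: restore in place, then clean up.
        with open(ckpt_path) as f:
            state.clear()
            state.update(json.load(f))  # stand-in for _checkpoint_connector.restore(...)
        os.remove(ckpt_path)  # stand-in for strategy.remove_checkpoint(ckpt_path)
    return state


if __name__ == "__main__":
    def failing_step(s):
        s["batch_size"] *= 2
        raise RuntimeError("simulated OOM")

    try:
        tune_with_cleanup({"batch_size": 32}, failing_step)
    except RuntimeError:
        pass  # the temp checkpoint is gone and the original state was restored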

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 36 additions & 31 deletions
@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as e:
+        raise e
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()

     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
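
Here the restore/cleanup also moves into a finally block, and the new lr_finder_finished flag ensures the learning-rate suggestion is only written back to the model when the search actually completed. A short sketch of that flag pattern, with hypothetical names and no Lightning dependency (unlike the Lightning code above, the sketch swallows the error so the flag path is visible):

from typing import Callable, Optional


def find_lr_with_cleanup(model: dict, run_search: Callable[[dict], dict],
                         suggest: Callable[[dict], Optional[float]]) -> Optional[float]:
    """Restore the model whether or not the search succeeds, but only apply
    the suggested value when the search ran to completion."""
    snapshot = dict(model)  # stand-in for the temporary checkpoint
    finished = False
    results: dict = {}
    try:
        results = run_search(model)  # may raise partway through
        finished = True
    except Exception:
        # Swallowed here for illustration only; the Lightning code re-raises.
        pass
    finally:
        model.clear()
        model.update(snapshot)  # restore, success or failure

    if not finished:
        return None
    lr = suggest(results)
    if lr is not None:
        model["lr"] = lr  # persist the suggestion after the restore
    return lr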
