@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception:
+        raise
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
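For context, the change above is the standard try/finally cleanup pattern: run the sweep, record whether it completed, restore trainer state unconditionally, and only act on the suggestion if the sweep finished. Below is a minimal sketch of that pattern under stated assumptions; run_lr_sweep, restore_trainer_state, and apply_suggestion are hypothetical placeholders, not Lightning APIs.

    def find_lr_safely(trainer, run_lr_sweep, restore_trainer_state, apply_suggestion):
        # Hypothetical sketch of the try/finally pattern used in the diff above.
        finished = False
        results = None
        try:
            # The sweep may raise (diverging loss, OOM, KeyboardInterrupt, ...)
            results = run_lr_sweep(trainer)
            finished = True
        finally:
            # Runs on success and failure alike, so the trainer is always
            # returned to its pre-sweep state before the real fit().
            restore_trainer_state(trainer)

        # Gate on the flag so a partial or failed sweep is never applied.
        if finished and results is not None:
            apply_suggestion(trainer, results)
        return results if finished else None

Because the restore lives in the finally block it executes even when the sweep raises, and the finished flag keeps a learning-rate suggestion from a partial run out of the model.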