@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as ex:
+        raise ex
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
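The design choice in this hunk: the search runs inside try/finally so the checkpointed model and optimizer state are restored whether or not the LR search raises, and the `lr_finder_finished` flag ensures the suggested learning rate is only applied after a complete run. Below is a minimal sketch of that control flow, not part of the PR; `run_search`, `restore_initial_state`, and `apply_suggestion` are hypothetical stand-ins for the Trainer internals touched above.

    def lr_find_with_cleanup(run_search, restore_initial_state, apply_suggestion):
        """Run an LR search, always restore state, and only apply a suggestion on success."""
        finished = False
        try:
            run_search()           # may raise (e.g. diverging loss, bad dataloader)
            finished = True        # only set once the search ran to completion
        finally:
            restore_initial_state()  # executes on success and on failure
        if finished:
            apply_suggestion()     # persist the suggested LR for the real training run


    def diverging_search():
        raise RuntimeError("loss diverged")


    # Usage: even though the search raises, the restore step still executes,
    # and no suggestion is applied.
    try:
        lr_find_with_cleanup(
            run_search=diverging_search,
            restore_initial_state=lambda: print("state restored"),
            apply_suggestion=lambda: print("suggestion applied"),
        )
    except RuntimeError as err:
        print(f"search failed, but cleanup ran: {err}")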