
Commit 92d689e

NikolasWolke authored and lexierule committed
Support None from training_step in LRFinder (#18129)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit 28c401c)
1 parent 9cf52ec commit 92d689e

File tree

3 files changed: +44 -2 lines changed

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 7 additions & 0 deletions

@@ -395,6 +395,13 @@ def on_train_batch_end(
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
 
+        # _AutomaticOptimization.run turns None STEP_OUTPUT into an empty dict
+        if not outputs:
+            # need to add an element, because we also added one element to lrs in on_train_batch_start
+            # so add nan, because they are not considered when computing the suggestion
+            self.losses.append(float("nan"))
+            return
+
         if self.progress_bar:
             self.progress_bar.update()
 
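For intuition, here is a minimal standalone sketch (simplified; not Lightning's actual callback code) of the invariant the new guard preserves: `on_train_batch_start` records a learning rate for every batch, so `losses` must receive an entry even when `training_step` returns None, otherwise the two lists drift out of alignment.

```python
# Standalone sketch of the invariant the guard preserves (simplified; not
# Lightning's actual callback code).
lrs = []
losses = []

def on_train_batch_start(lr):
    lrs.append(lr)

def on_train_batch_end(outputs):
    # _AutomaticOptimization.run turns a None STEP_OUTPUT into an empty dict,
    # so `not outputs` catches both None and {}.
    if not outputs:
        # NaN keeps both lists the same length; non-finite losses are not
        # considered when the suggestion is computed.
        losses.append(float("nan"))
        return
    losses.append(outputs["loss"])

on_train_batch_start(1e-4)
on_train_batch_end({})  # training_step returned None for this batch
assert len(lrs) == len(losses) == 1
```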

src/lightning/pytorch/tuner/tuning.py

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ def lr_find(
         max_lr: float = 1,
         num_training: int = 100,
         mode: str = "exponential",
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = True,
         attr_name: str = "",
     ) -> Optional["pl.tuner.lr_finder._LRFinder"]:
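Passing `early_stop_threshold=None` was already supported at runtime (the test suite exercises it); the annotation change makes the signature say so. A minimal usage sketch, assuming only the public Tuner API shown in this diff (`BoringModel` and the step count are illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner.tuning import Tuner

model = BoringModel()
trainer = Trainer()
tuner = Tuner(trainer)

# None disables early stopping of the LR search, so all `num_training` steps
# run; update_attr=False because BoringModel has no `lr` attribute to update.
lr_finder = tuner.lr_find(
    model, num_training=20, early_stop_threshold=None, update_attr=False
)
print(lr_finder.suggestion())
```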

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 36 additions & 1 deletion

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import os
 from copy import deepcopy
+from typing import Any
 from unittest import mock
 
 import pytest
@@ -26,6 +28,7 @@
 from lightning.pytorch.tuner.lr_finder import _LRFinder
 from lightning.pytorch.tuner.tuning import Tuner
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -228,7 +231,7 @@ def __init__(self):
     lr_finder = tuner.lr_find(model, early_stop_threshold=None)
 
     assert lr_finder.suggestion() != 1e-3
-    assert len(lr_finder.results["lr"]) == 100
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 100
     assert lr_finder._total_batch_idx == 199
 
 
@@ -502,3 +505,35 @@ def configure_optimizers(self):
 
     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != num_lr_tuner_training_steps
+
+
+def test_lr_finder_training_step_none_output(tmpdir):
+    # add some nans into the skipped steps (first 10) but also into the steps used to compute the lr
+    none_steps = [5, 12, 17]
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 0.123
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            if self.trainer.global_step in none_steps:
+                return None
+
+            return super().training_step(batch, batch_idx)
+
+    seed_everything(1)
+    model = CustomBoringModel()
+
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    tuner = Tuner(trainer)
+    # restrict number of steps for faster test execution
+    # and disable early stopping to easily check expected number of lrs and losses
+    lr_finder = tuner.lr_find(model=model, update_attr=True, num_training=20, early_stop_threshold=None)
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 20
+    assert torch.isnan(torch.tensor(lr_finder.results["loss"])[none_steps]).all()
+
+    suggested_lr = lr_finder.suggestion()
+    assert math.isfinite(suggested_lr)
+    assert math.isclose(model.lr, suggested_lr)
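The last three assertions are the point of the fix: three of the twenty recorded losses are NaN, yet the suggestion is still finite. A rough sketch of why, assuming (per the comment in the lr_finder.py hunk above) that non-finite losses are masked out before searching for the steepest descent; the slope computation here is illustrative, not the library's exact algorithm:

```python
import math

import torch

lrs = torch.tensor([1e-5, 1e-4, 1e-3, 1e-2, 1e-1])
losses = torch.tensor([1.0, float("nan"), 0.6, 0.4, 0.9])

# Drop the NaN placeholders, then pick the lr where the loss falls fastest.
finite = torch.isfinite(losses)
slopes = torch.gradient(losses[finite])[0]
suggested = lrs[finite][torch.argmin(slopes)].item()
assert math.isfinite(suggested)  # NaN entries never reach the search
```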
