@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import os
 from copy import deepcopy
+from typing import Any
 from unittest import mock
 
 import pytest
@@ -26,6 +28,7 @@
 from lightning.pytorch.tuner.lr_finder import _LRFinder
 from lightning.pytorch.tuner.tuning import Tuner
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -228,7 +231,7 @@ def __init__(self): |
     lr_finder = tuner.lr_find(model, early_stop_threshold=None)
 
     assert lr_finder.suggestion() != 1e-3
-    assert len(lr_finder.results["lr"]) == 100
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 100
     assert lr_finder._total_batch_idx == 199
 
 
@@ -502,3 +505,35 @@ def configure_optimizers(self): |
 
     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != num_lr_tuner_training_steps
+
+
+def test_lr_finder_training_step_none_output(tmpdir):
+    # add some nans into the skipped steps (first 10) but also into the steps used to compute the lr
+    none_steps = [5, 12, 17]
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 0.123
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            if self.trainer.global_step in none_steps:
+                return None
+
+            return super().training_step(batch, batch_idx)
+
+    seed_everything(1)
+    model = CustomBoringModel()
+
+    trainer = Trainer(default_root_dir=tmpdir)
+
+    tuner = Tuner(trainer)
+    # restrict number of steps for faster test execution
+    # and disable early stopping to easily check expected number of lrs and losses
+    lr_finder = tuner.lr_find(model=model, update_attr=True, num_training=20, early_stop_threshold=None)
+    assert len(lr_finder.results["lr"]) == len(lr_finder.results["loss"]) == 20
+    assert torch.isnan(torch.tensor(lr_finder.results["loss"])[none_steps]).all()
+
+    suggested_lr = lr_finder.suggestion()
+    assert math.isfinite(suggested_lr)
+    assert math.isclose(model.lr, suggested_lr)
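
For context on what the new test exercises: when `training_step` returns `None`, the lr finder records a NaN loss for that step (the test asserts this via `torch.isnan`), and `suggestion()` still has to return a finite learning rate. The snippet below is a minimal, self-contained sketch of that idea under the assumption that non-finite losses are simply skipped while searching for the steepest loss decrease; `suggest_lr`, `skip_begin`, and `skip_end` are made-up names for illustration, not Lightning's actual `_LRFinder` implementation.

```python
import math


def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
    """Pick the lr at the steepest loss decrease, ignoring non-finite losses."""
    best_slope = float("inf")
    best_lr = None
    prev = None  # last (lr, loss) pair whose loss was finite
    for lr, loss in list(zip(lrs, losses))[skip_begin : len(losses) - skip_end]:
        if not math.isfinite(loss):
            continue  # a NaN from a skipped batch must not poison the slope estimate
        if prev is not None:
            slope = (loss - prev[1]) / (math.log10(lr) - math.log10(prev[0]))
            if slope < best_slope:
                best_slope, best_lr = slope, lr
        prev = (lr, loss)
    return best_lr


# Mimic what the test sets up: exponentially increasing lrs with a few NaN losses.
lrs = [10 ** (-5 + 0.2 * i) for i in range(20)]
losses = [1.0 - 0.01 * i for i in range(20)]
losses[12] = losses[17] = float("nan")
assert math.isfinite(suggest_lr(lrs, losses))
```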