Commit 8821df0

Merge branch 'master' into patch-3
2 parents f906648 + f58a176

3 files changed: +17 -7 lines changed


docs/source-pytorch/cli/lightning_cli_intermediate_2.rst

Lines changed: 3 additions & 3 deletions
@@ -201,9 +201,10 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t

 .. code:: bash

-    python main.py fit --optimizer=Adam --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
+    python main.py fit --optimizer=Adam --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=train_loss

-Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
+(assuming you have a ``train_loss`` metric logged). Furthermore, any custom subclass of
+``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:

 .. code:: python

@@ -212,7 +213,6 @@ Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can
     from lightning.pytorch.cli import LightningCLI
     from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

-
     class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
         def step(self):
             print("", "using LitLRScheduler", "")

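The docs change above assumes the model actually logs a metric named ``train_loss``; ReduceLROnPlateau can only monitor metrics that exist. A minimal sketch of the logging side (the module below is illustrative, not part of this commit):

    import torch
    from lightning.pytorch import LightningModule

    class SketchModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # Log under the exact name passed to --lr_scheduler.monitor so the
            # scheduler has a value to track when deciding whether to reduce the LR.
            self.log("train_loss", loss)
            return loss
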
src/lightning/pytorch/cli.py

Lines changed: 7 additions & 0 deletions
@@ -66,6 +66,13 @@


 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
+    """Custom ReduceLROnPlateau scheduler that extends PyTorch's ReduceLROnPlateau.
+
+    This class adds a `monitor` attribute to the standard PyTorch ReduceLROnPlateau to specify which metric should be
+    tracked for learning rate adjustment.
+
+    """
+
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
         super().__init__(optimizer, *args, **kwargs)
         self.monitor = monitor

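The `monitor` attribute documented here is what gets read downstream when the scheduler is wired up. Outside the CLI, the same subclass could be used manually in ``configure_optimizers`` (a minimal sketch; the optimizer choice and metric name are illustrative, not part of this commit):

    import torch
    from lightning.pytorch.cli import ReduceLROnPlateau

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # The extra `monitor` argument records which logged metric the scheduler
        # should watch; all other arguments pass through to
        # torch.optim.lr_scheduler.ReduceLROnPlateau.
        scheduler = ReduceLROnPlateau(optimizer, monitor="train_loss")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": scheduler.monitor},
        }
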
tests/tests_pytorch/utilities/test_model_summary.py

Lines changed: 7 additions & 4 deletions
@@ -319,13 +319,16 @@ def test_empty_model_size(max_depth):


 @pytest.mark.parametrize(
-    "accelerator",
+    ("accelerator", "precision"),
     [
-        pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
-        pytest.param("mps", marks=RunIf(mps=True)),
+        pytest.param("gpu", "16-true", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("gpu", "32-true", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("gpu", "64-true", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("mps", "16-true", marks=RunIf(mps=True)),
+        pytest.param("mps", "32-true", marks=RunIf(mps=True)),
+        # Note: "64-true" with "mps" is skipped because MPS does not support float64
     ],
 )
-@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
 def test_model_size_precision(tmp_path, accelerator, precision):
     """Test model size for different precision types."""
     model = PreCalculatedModel(precision=int(precision.split("-")[0]))

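The reason for merging the two decorators above: stacked @pytest.mark.parametrize decorators always generate the full cross product, so there was no way to drop only the unsupported ("mps", "64-true") combination. Listing explicit pairs in a single decorator makes that exclusion possible. A generic illustration of the pattern (names here are not from the test suite):

    import pytest

    # Stacked decorators would run all four (a, b) combinations:
    #   @pytest.mark.parametrize("a", [1, 2])
    #   @pytest.mark.parametrize("b", ["x", "y"])
    # A single tuple-parametrize runs only the pairs listed:
    @pytest.mark.parametrize(("a", "b"), [(1, "x"), (1, "y"), (2, "x")])
    def test_pairs(a, b):
        assert (a, b) != (2, "y")  # the omitted pair never runs
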