Commit 54c84c7

Merge branch 'master' into generic-weight-averaging
2 parents 6012908 + 3ed9d4e commit 54c84c7

File tree

5 files changed: +116 −10 lines changed

src/lightning/fabric/utilities/distributed.py

Lines changed: 5 additions & 1 deletion
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:


 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    """Return corresponding distributed backend for a given device."""
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    return "gloo"


 class _DatasetSamplerWrapper(Dataset):
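
Note: the helper now defers to PyTorch's backend registry instead of hard-coding cuda -> nccl. A minimal sketch of the resulting behavior, assuming a PyTorch build where `torch.distributed.Backend.default_device_backend_map` holds entries such as `{"cpu": "gloo", "cuda": "nccl"}`; the function name below is illustrative, not Lightning's:

import torch
import torch.distributed


def default_backend_for(device: torch.device) -> str:
    # Mirror the patched helper: look the device type up in PyTorch's
    # device -> backend map and fall back to "gloo" for unknown types.
    mapping = torch.distributed.Backend.default_device_backend_map
    return mapping.get(device.type, "gloo")


print(default_backend_for(torch.device("cuda:0")))  # typically "nccl"
print(default_backend_for(torch.device("cpu")))     # "gloo"
print(default_backend_for(torch.device("meta")))    # unknown type, falls back to "gloo"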

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 -


+- Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
+
 ---

 ## [2.5.3] - 2025-08-13

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 15 additions & 9 deletions
@@ -276,24 +276,30 @@ def _lr_find(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.enable()

-    # Update lr attr if required
+    # Update results across ranks
     lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-    if update_attr:
-        lr = lr_finder.suggestion()
-
-        # TODO: log lr.results to self.logger
-        if lr is not None:
-            lightning_setattr(model, attr_name, lr)
-            log.info(f"Learning rate set to {lr}")

-    # Restore initial state of model
+    # Restore initial state of model (this will also restore the original optimizer state)
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
     trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
     trainer.fit_loop.setup_data()
+
+    # Apply LR suggestion after restoring so it persists for the real training run
+    # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
+    if update_attr:
+        lr = lr_finder.suggestion()
+        if lr is not None:
+            # update the attribute on the LightningModule (e.g., lr or learning_rate)
+            lightning_setattr(model, attr_name, lr)
+            # also update the currently active optimizer(s) so training continues with the suggested LR
+            for opt in trainer.optimizers or []:
+                for pg in opt.param_groups:
+                    pg["lr"] = lr
+            log.info(f"Learning rate set to {lr}")
     return lr_finder

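
Note: the reordering matters because restoring the checkpoint also restores the optimizer state dict, which carries the pre-finder learning rate. A standalone sketch of that pitfall in plain PyTorch (not Lightning code; the values are illustrative):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-5)
snapshot = opt.state_dict()            # state captured before the LR finder runs

opt.param_groups[0]["lr"] = 1e-3       # apply a "suggested" LR too early...
opt.load_state_dict(snapshot)          # ...restoring the snapshot reverts it
print(opt.param_groups[0]["lr"])       # 1e-05 again: the suggestion is lost

opt.param_groups[0]["lr"] = 1e-3       # applying it after the restore sticks
print(opt.param_groups[0]["lr"])       # 0.001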

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 22 additions & 0 deletions
@@ -17,6 +17,7 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
+    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()


+def test_get_default_process_group_backend_for_device():
+    """Test that each device type maps to its correct default process group backend."""
+    # register a custom backend for test
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def mock_backend(store, group_rank, group_size, timeout):
+        pass
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # test that the default backend is correctly set for each device
+    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
+    backends = ["gloo", "nccl", "pccl"]
+    for device, backend in zip(devices, backends):
+        assert _get_default_process_group_backend_for_device(device) == backend
+
+
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 72 additions & 0 deletions
@@ -619,6 +619,78 @@ def test_gradient_correctness():
     assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"


+def test_lr_finder_callback_applies_lr_after_restore(tmp_path):
+    """LearningRateFinder used as a callback should apply its suggested LR to the optimizer used after state
+    restoration."""
+
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torch.utils.data import DataLoader, Dataset
+
+    from lightning.pytorch.callbacks import LearningRateMonitor
+
+    class RandomDataset(Dataset):
+        def __init__(self, n: int = 256, in_dim: int = 28 * 28):
+            self.x = torch.randn(n, in_dim)
+            self.y = torch.randn(n, in_dim)
+
+        def __len__(self) -> int:
+            return len(self.x)
+
+        def __getitem__(self, idx):
+            return self.x[idx], self.y[idx]
+
+    class TinyAE(BoringModel):
+        def __init__(self, lr: float = 1e-5):
+            super().__init__()
+            self.save_hyperparameters()
+            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
+            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            x, y = batch
+            z = self.encoder(x)
+            x_hat = self.decoder(z)
+            loss = F.mse_loss(x_hat, y)
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+    seed_everything(123)
+
+    ds = RandomDataset(n=512)
+    train_loader = DataLoader(ds, batch_size=64, shuffle=False)
+
+    model = TinyAE(lr=1e-5)
+
+    lr_finder_cb = LearningRateFinder()  # default update_attr=True should apply suggestion
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        callbacks=[lr_finder_cb, lr_monitor],
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        log_every_n_steps=1,
+    )
+
+    trainer.fit(model, train_loader)
+    assert model.hparams.lr is not None
+    # Ensure LR Finder produced a suggestion for this setup; if not, the test can't assert application
+    assert lr_finder_cb.optimal_lr is not None, "LR Finder should have computed results"
+    suggestion = lr_finder_cb.optimal_lr.suggestion()
+    assert suggestion is not None, "LR Finder should produce a suggestion for this setup"
+
+    # Verify that the optimizer used for subsequent training has the suggested LR applied
+    assert trainer.optimizers, "Trainer should have an optimizer after fit"
+    current_lr = trainer.optimizers[0].param_groups[0]["lr"]
+    assert current_lr == pytest.approx(suggestion), (
+        f"LR Finder suggestion {suggestion} should be applied to optimizer, but got {current_lr}"
+    )
+
+
 def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
     """Test that exponential and linear modes produce different but valid suggestions.
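
Note: the test above exercises the callback path; the same flow can also be driven manually through the Tuner API. A hedged sketch using Lightning's demo `BoringModel`, assuming the public `Tuner.lr_find` entry point and illustrative hyperparameters (not part of this commit):

from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.tuner import Tuner

model = BoringModel()
model.learning_rate = 1e-5                      # attribute the finder will update
train = DataLoader(RandomDataset(32, 64), batch_size=2)

trainer = Trainer(max_epochs=1, enable_progress_bar=False)
tuner = Tuner(trainer)

# Runs the LR range test; with update_attr=True the suggestion is written back
# to model.learning_rate before the real fit starts.
lr_finder = tuner.lr_find(model, train_dataloaders=train, update_attr=True)
if lr_finder is not None:
    print(lr_finder.suggestion())

trainer.fit(model, train)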
