Commit 8cefd43

Merge branch 'master' into fix/trainer_init_module_device
2 parents e4c2ac5 + 3c81316 commit 8cefd43

14 files changed, +208 -51 lines changed


src/lightning/fabric/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -22,14 +22,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- let `_get_default_process_group_backend_for_device` support more hardware platforms (
+[#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
 
 
 ### Fixed
 
 - Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
 
 
+- Respect `verbose=False` in `seed_everything` when no seed is provided
+
+
 ---
 
 ## [2.5.4] - 2025-08-29

src/lightning/fabric/strategies/ddp.py

Lines changed: 11 additions & 1 deletion
@@ -160,7 +160,17 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
         else:
-            torch.distributed.barrier()
+            # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error
+            try:
+                torch.distributed.barrier()
+            except RuntimeError as e:
+                if "PrivateUse1HooksInterface" in str(e):
+                    # Fallback: Use all_reduce as barrier - all processes must participate
+                    # This achieves the same synchronization effect as barrier()
+                    dummy_tensor = torch.tensor(0.0, device=self.root_device)
+                    torch.distributed.all_reduce(dummy_tensor)
+                else:
+                    raise
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
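Note on the fallback: all_reduce is itself a blocking collective, so reducing a throwaway scalar forces every rank to rendezvous exactly as barrier() would. A minimal standalone sketch of the same idea, assuming a process group has already been initialized (for example with the gloo backend); cpu_safe_barrier is a hypothetical helper name, not part of the Lightning API:

import torch
import torch.distributed as dist

def cpu_safe_barrier() -> None:
    # Same pattern as the diff above: try the native barrier first and fall back
    # to an all_reduce of a dummy tensor, which blocks until every rank joins.
    try:
        dist.barrier()
    except RuntimeError as err:
        if "PrivateUse1HooksInterface" not in str(err):
            raise
        dummy = torch.tensor(0.0)  # CPU tensor here; use the strategy's root device in general
        dist.all_reduce(dummy)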

src/lightning/fabric/utilities/distributed.py

Lines changed: 5 additions & 1 deletion
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    """Return corresponding distributed backend for a given device."""
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
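For context, torch.distributed.Backend.default_device_backend_map is a plain dict from device type to backend name; a typical CUDA build contains at least {"cpu": "gloo", "cuda": "nccl"}, and backends registered via Backend.register_backend(..., devices=[...]) add their own entries, which is what lets the rewritten lookup cover additional hardware platforms. A small illustrative sketch (the exact map contents depend on the PyTorch build):

import torch
import torch.distributed as dist

# Inspect the mapping the new implementation consults.
print(dist.Backend.default_device_backend_map)

# Equivalent lookup to the rewritten function: known device types use their
# registered backend, anything else falls back to "gloo".
device = torch.device("cuda", 0)
backend = dist.Backend.default_device_backend_map.get(device.type, "gloo")
print(backend)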

src/lightning/fabric/utilities/seed.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
         env_seed = os.environ.get("PL_GLOBAL_SEED")
         if env_seed is None:
             seed = 0
-            rank_zero_warn(f"No seed found, seed set to {seed}")
+            if verbose:
+                rank_zero_warn(f"No seed found, seed set to {seed}")
         else:
             try:
                 seed = int(env_seed)
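A quick usage sketch of the new behaviour, assuming PL_GLOBAL_SEED is not set in the environment:

from lightning.fabric.utilities.seed import seed_everything

# With no seed argument and no PL_GLOBAL_SEED in the environment, the seed
# still falls back to 0; verbose=False now suppresses the "No seed found" warning.
seed_everything(verbose=False)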

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
+
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
 ---
 
 ## [2.5.4] - 2025-08-29

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 16 additions & 13 deletions
@@ -76,24 +76,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
-    elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        elif mode == "binsearch":
+            new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
 
-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()
 
-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as ex:
+        raise ex
+    finally:
+        __scale_batch_restore_params(trainer, params)
 
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
 
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
 
     return new_size
 
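_scale_batch_size is normally reached through the Tuner API; with the try/finally above, the patched attributes, the progress bar state, and the temporary tuner checkpoint are restored even when a trial run crashes for a reason other than OOM. A hedged usage sketch, where `model` is a placeholder for any user LightningModule that exposes a `batch_size` attribute:

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)
# `model` is a placeholder; mode can be "power" or "binsearch", as in the diff above.
new_batch_size = tuner.scale_batch_size(model, mode="power")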

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 36 additions & 31 deletions
@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as ex:
+        raise ex
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
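The user-facing path is Tuner.lr_find; with the new lr_finder_finished flag, the suggested learning rate is only written back to the module when the search actually completed, while the finally block still restores the pre-tuning checkpoint state after a crash. A short hedged sketch, where `model` is again a placeholder LightningModule:

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)
# `model` is a placeholder LightningModule with an `lr` (or `learning_rate`) attribute.
lr_finder = tuner.lr_find(model, update_attr=True)  # drives the _lr_find() shown above
if lr_finder is not None:
    print(lr_finder.suggestion())  # suggested lr, or None if no suggestion could be computed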

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 22 additions & 0 deletions
@@ -17,6 +17,7 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
+    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()
 
 
+def test_get_default_process_group_backend_for_device():
+    """Test that each device type maps to its correct default process group backend."""
+    # register a custom backend for test
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def mock_backend(store, group_rank, group_size, timeout):
+        pass
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # test that the default backend is correctly set for each device
+    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
+    backends = ["gloo", "nccl", "pccl"]
+    for device, backend in zip(devices, backends):
+        assert _get_default_process_group_backend_for_device(device) == backend
+
+
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor

tests/tests_fabric/utilities/test_seed.py

Lines changed: 8 additions & 0 deletions
@@ -72,6 +72,14 @@ def test_seed_everything_accepts_valid_seed_from_env():
     assert seed_everything() == 17
 
 
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_seed_everything_non_verbose_no_warning():
+    """Ensure that no warning is emitted when verbose is False and no seed is provided."""
+    with warnings.catch_warnings(record=True) as caught:
+        seed_everything(verbose=False)
+    assert caught == []
+
+
 def test_reset_seed_no_op():
     """Test that the reset_seed function is a no-op when seed_everything() was not used."""
     assert "PL_GLOBAL_SEED" not in os.environ

tests/tests_fabric/utilities/test_spike.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
     )
 
 
-@pytest.mark.flaky(max_runs=3)
+@pytest.mark.flaky(reruns=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
     # NOTE FOR ALL FOLLOWING TESTS:
