Skip to content

Commit aa4cef6

Browse files
committed
add test coverage
1 parent 0ac7fcf commit aa4cef6

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8282
- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
8383

8484

85+
- Fixed `ModelCheckpoint.file_exists` using broadcast in DDP, reducing memory usage when checking for existing checkpoints ([#19674](https://github.com/Lightning-AI/pytorch-lightning/issues/19674))
86+
87+
8588
---
8689

8790
## [2.5.6] - 2025-11-05

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -999,10 +999,6 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
999999
def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
10001000
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
10011001
state to diverge between ranks."""
1002-
# Single-process or strategies without distributed world size: no need for coordination
1003-
if trainer.world_size == 1:
1004-
return self._fs.exists(filepath)
1005-
10061002
# In distributed setups, only global rank 0 touches the filesystem
10071003
local_decision = self._fs.exists(filepath) if trainer.is_global_zero else False
10081004
# Reduce the decision across ranks using an "any"-style reduction to decide if the file exists anywhere

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2180,3 +2180,36 @@ def on_validation_epoch_end(self):
21802180
assert len(checkpoint_files) == expected_files, (
21812181
f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
21822182
)
2183+
2184+
2185+
def test_model_checkpoint_file_exists_distributed_branch(tmp_path):
    """Ensure the distributed branch of ModelCheckpoint.file_exists uses reduce_boolean_decision."""
    checkpoint = ModelCheckpoint(dirpath=tmp_path)
    reduce_calls = []

    class _RecordingStrategy:
        # Mimics Strategy.reduce_boolean_decision: record the arguments and
        # echo the local decision back unchanged (single-rank stand-in).
        def reduce_boolean_decision(self, decision, all=True):
            reduce_calls.append((decision, all))
            return decision

    class _FakeTrainer:
        # Minimal trainer surface that file_exists reads: world_size,
        # is_global_zero, and the strategy used for the reduction.
        def __init__(self, is_global_zero: bool):
            self.world_size = 2
            self.is_global_zero = is_global_zero
            self.strategy = _RecordingStrategy()

    # Global rank 0: the filesystem IS consulted, and its True result is
    # reduced across ranks with an "any"-style (all=False) reduction.
    checkpoint._fs.exists = Mock(return_value=True)
    assert checkpoint.file_exists("ignored", _FakeTrainer(is_global_zero=True))
    checkpoint._fs.exists.assert_called_once_with("ignored")
    assert reduce_calls == [(True, False)]

    # Non-zero ranks: the filesystem is never touched; a local False is
    # contributed to the same "any"-style reduction.
    reduce_calls.clear()
    checkpoint._fs.exists = Mock(return_value=True)
    assert not checkpoint.file_exists("ignored", _FakeTrainer(is_global_zero=False))
    checkpoint._fs.exists.assert_not_called()
    assert reduce_calls == [(False, False)]

0 commit comments

Comments
 (0)