Commit b09e96e

Fix ModelCheckpoint file_exists OOM in DDP (#21380)
* Fix ModelCheckpoint.file_exists OOM in DDP
* Document ModelCheckpoint.file_exists DDP memory fix
* Update src/lightning/pytorch/callbacks/model_checkpoint.py

---------

Co-authored-by: Justus Schock <[email protected]>
1 parent ef489f2 commit b09e96e

3 files changed (+33, -3 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
 
 
+- Fixed `ModelCheckpoint.file_exists` using broadcast in DDP, reducing memory usage when checking for existing checkpoints ([#19674](https://github.com/Lightning-AI/pytorch-lightning/issues/19674))
+
+
 ---
 
 ## [2.5.6] - 2025-11-05
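The entry above describes replacing a cross-rank broadcast of a Python object with a boolean reduction when checking whether a checkpoint file already exists. As a rough illustration of the cheaper pattern, the sketch below performs an "any"-style boolean reduction with raw torch.distributed. It is illustrative only, not Lightning's implementation, and it assumes a process group is already initialized with one CUDA device per rank.

    # Illustrative sketch (not Lightning's code): an "any"-style boolean reduction.
    import torch
    import torch.distributed as dist

    def file_exists_anywhere(local_exists: bool, device: torch.device) -> bool:
        # Encode the per-rank decision as a one-element tensor and sum it across
        # ranks; a positive total means at least one rank found the file.
        flag = torch.tensor(int(local_exists), device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return bool(flag.item() > 0)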

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 5 additions & 3 deletions
@@ -997,10 +997,12 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
             yaml.dump(best_k, fp)
 
     def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
-        """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
+        """Checks if a file exists on rank 0 and synchronizes the result to all other ranks, preventing the internal
         state to diverge between ranks."""
-        exists = self._fs.exists(filepath)
-        return trainer.strategy.broadcast(exists)
+        # In distributed setups, only global rank 0 touches the filesystem
+        local_decision = self._fs.exists(filepath) if trainer.is_global_zero else False
+        # Reduce the decision across ranks using an "any"-style reduction to decide if the file exists anywhere
+        return trainer.strategy.reduce_boolean_decision(local_decision, all=False)
 
     def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
         """Checks if the previous checkpoint should be deleted.

tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 25 additions & 0 deletions
@@ -121,3 +121,28 @@ def on_train_epoch_end(self):
     trainer.fit(model)
     if os.getenv("LOCAL_RANK") == "0":
         assert save_mock.call_count == expected
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_model_checkpoint_ddp_monitor_none(tmp_path):
+    """Ensure that ModelCheckpoint with monitor=None works correctly under DDP and exercises the file_exists path."""
+
+    model = BoringModel()
+    checkpoint = callbacks.ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[checkpoint],
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        max_epochs=1,
+        strategy="ddp",
+        accelerator="gpu",
+        devices=2,
+        limit_train_batches=2,
+        limit_val_batches=0,
+    )
+
+    trainer.fit(model)
+    if os.getenv("LOCAL_RANK") == "0":
+        assert checkpoint.best_model_path
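The test runs ModelCheckpoint with monitor=None under a two-GPU DDP strategy so that the reworked file_exists path is exercised, as its docstring states. For readers less familiar with that configuration, here is a minimal single-process sketch (assumed defaults, not part of the test suite): with monitor=None the callback saves the most recent checkpoint rather than tracking a metric, and best_model_path points at that file.

    # Minimal single-process sketch (not part of the test suite).
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.demos.boring_classes import BoringModel

    ckpt = ModelCheckpoint(dirpath="checkpoints/", monitor=None, save_top_k=1)
    trainer = Trainer(max_epochs=1, callbacks=[ckpt], limit_train_batches=2, limit_val_batches=0)
    trainer.fit(BoringModel())
    print(ckpt.best_model_path)  # path of the checkpoint written at the end of training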
