Skip to content

Commit 0ac7fcf

Browse files
committed
Fix ModelCheckpoint file_exists OOM in DDP
- Use strategy.reduce_boolean_decision instead of broadcast in ModelCheckpoint.file_exists
- Ensure only global rank 0 touches the filesystem
- Avoid broadcast_object_list for a simple boolean in DDP
- Add a small DDP test with monitor=None to exercise this path
1 parent 8f702b3 commit 0ac7fcf

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -999,8 +999,14 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
999999
def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
10001000
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
10011001
state to diverge between ranks."""
1002-
exists = self._fs.exists(filepath)
1003-
return trainer.strategy.broadcast(exists)
1002+
# Single-process or strategies without distributed world size: no need for coordination
1003+
if trainer.world_size == 1:
1004+
return self._fs.exists(filepath)
1005+
1006+
# In distributed setups, only global rank 0 touches the filesystem
1007+
local_decision = self._fs.exists(filepath) if trainer.is_global_zero else False
1008+
# Reduce the decision across ranks using an "any"-style reduction to decide if the file exists anywhere
1009+
return trainer.strategy.reduce_boolean_decision(local_decision, all=False)
10041010

10051011
def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
10061012
"""Checks if the previous checkpoint should be deleted.

tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,28 @@ def on_train_epoch_end(self):
121121
trainer.fit(model)
122122
if os.getenv("LOCAL_RANK") == "0":
123123
assert save_mock.call_count == expected
124+
125+
126+
@RunIf(min_cuda_gpus=2, standalone=True)
127+
def test_model_checkpoint_ddp_monitor_none(tmp_path):
128+
"""Ensure that ModelCheckpoint with monitor=None works correctly under DDP and exercises the file_exists path."""
129+
130+
model = BoringModel()
131+
checkpoint = callbacks.ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1)
132+
133+
trainer = Trainer(
134+
default_root_dir=tmp_path,
135+
callbacks=[checkpoint],
136+
enable_progress_bar=False,
137+
enable_model_summary=False,
138+
max_epochs=1,
139+
strategy="ddp",
140+
accelerator="gpu",
141+
devices=2,
142+
limit_train_batches=2,
143+
limit_val_batches=0,
144+
)
145+
146+
trainer.fit(model)
147+
if os.getenv("LOCAL_RANK") == "0":
148+
assert checkpoint.best_model_path

0 commit comments

Comments
 (0)