Skip to content

Commit 21f9f28

Browse files
committed
Make asyncio checkpointing work if validate/fit is called more than once.
1 parent afa7d56 commit 21f9f28

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2525

2626
### Fixed
2727

28-
-
28+
- Fixed `AsyncCheckpointIO` threadpool exception if calling fit or validate more than once.
2929

3030

3131
---

src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,23 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):
3333

3434
def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
3535
super().__init__(checkpoint_io)
36-
37-
self._executor = ThreadPoolExecutor(max_workers=1)
36+
self._executor: Optional[ThreadPoolExecutor] = None
3837
self._error: Optional[BaseException] = None
3938

39+
# CheckpointIO doesn't have a setup method, so we have to do something like this.
40+
# We can't do setup in __init__ because if train or validate is called more than once the
41+
# teardown method deletes the executor.
42+
def _ensure_setup(self) -> None:
43+
if self._executor is None:
44+
self._executor = ThreadPoolExecutor(max_workers=1)
45+
self._error: Optional[BaseException] = None
46+
4047
@override
4148
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
4249
"""Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
4350

51+
self._ensure_setup()
52+
4453
def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
4554
try:
4655
assert self.checkpoint_io is not None
@@ -58,8 +67,10 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
5867
@override
5968
def teardown(self) -> None:
6069
"""This method is called to close the threads."""
61-
self._executor.shutdown(wait=True)
70+
if self._executor is not None:
71+
self._executor.shutdown(wait=True)
72+
self._executor = None
6273

63-
# if an error was raised anytime in any of the `executor.submit` calls
64-
if self._error:
65-
raise self._error
74+
# if an error was raised anytime in any of the `executor.submit` calls
75+
if self._error:
76+
raise self._error

tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def on_fit_start(self):
127127
enable_progress_bar=False,
128128
enable_model_summary=False,
129129
)
130+
131+
# We add a validate step to test that async works when fit or validate is called multiple times.
132+
trainer.validate(model)
133+
130134
trainer.fit(model)
131135

132136
assert checkpoint_plugin.save_checkpoint.call_count == 3

0 commit comments

Comments
 (0)