Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `AsyncCheckpointIO` threadpool exception when calling fit or validate more than once.


- Fixed progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))


Expand Down
23 changes: 17 additions & 6 deletions src/lightning/pytorch/plugins/io/async_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,23 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):

def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
    """Wrap ``checkpoint_io`` so saves run asynchronously on a background thread.

    The executor is created lazily (see ``_ensure_setup``) rather than here,
    because ``teardown`` destroys it and ``fit``/``validate`` may be called
    more than once on the same plugin instance.
    """
    super().__init__(checkpoint_io)

    # Created on first use; reset to None by teardown().
    self._executor: Optional[ThreadPoolExecutor] = None
    # Holds the first exception raised inside a background save, re-raised in teardown().
    self._error: Optional[BaseException] = None

# CheckpointIO doesn't have a setup method so we have to do something like.
# We can't do setup in __init__ because if train or validate is called more than once the
# teardown method deletes the executor.
def _ensure_setup(self) -> None:
if self._executor is None:
self._executor = ThreadPoolExecutor(max_workers=1)
self._error: Optional[BaseException] = None

@override
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
"""Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""

self._ensure_setup()

def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
try:
assert self.checkpoint_io is not None
Expand All @@ -58,8 +67,10 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
@override
def teardown(self) -> None:
"""This method is called to close the threads."""
self._executor.shutdown(wait=True)
if self._executor is not None:
self._executor.shutdown(wait=True)
self._executor = None

# if an error was raised anytime in any of the `executor.submit` calls
if self._error:
raise self._error
# if an error was raised anytime in any of the `executor.submit` calls
if self._error:
raise self._error
4 changes: 4 additions & 0 deletions tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def on_fit_start(self):
enable_progress_bar=False,
enable_model_summary=False,
)

# We add a validate step to test that async works when fit or validate is called multiple times.
trainer.validate(model)

trainer.fit(model)

assert checkpoint_plugin.save_checkpoint.call_count == 3
Expand Down