diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index aaa9e640a8110..9c7623a941dd5 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `AsyncCheckpointIO` raising a thread pool exception when `fit` or `validate` is called more than once ([#20952](https://github.com/Lightning-AI/pytorch-lightning/pull/20952))
+
+
 - fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))
 
 
diff --git a/src/lightning/pytorch/plugins/io/async_plugin.py b/src/lightning/pytorch/plugins/io/async_plugin.py
index 67c02189c541e..0c1e3e55c03cb 100644
--- a/src/lightning/pytorch/plugins/io/async_plugin.py
+++ b/src/lightning/pytorch/plugins/io/async_plugin.py
@@ -33,14 +33,23 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):
 
     def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
         super().__init__(checkpoint_io)
-
-        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._executor: Optional[ThreadPoolExecutor] = None
         self._error: Optional[BaseException] = None
 
+    # CheckpointIO has no setup hook, so the executor is created lazily on first save.
+    # It cannot be created once in __init__ because `teardown` shuts the executor down,
+    # and fit or validate may each be called more than once on the same Trainer.
+    def _ensure_setup(self) -> None:
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=1)
+            self._error = None
+
     @override
     def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
         """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
+        self._ensure_setup()
+
         def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
             try:
                 assert self.checkpoint_io is not None
@@ -58,8 +67,10 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
     @override
     def teardown(self) -> None:
         """This method is called to close the threads."""
-        self._executor.shutdown(wait=True)
+        if self._executor is not None:
+            self._executor.shutdown(wait=True)
+            self._executor = None
 
-        # if an error was raised anytime in any of the `executor.submit` calls
-        if self._error:
-            raise self._error
+        # re-raise any error from one of the earlier `executor.submit` calls
+        if self._error:
+            raise self._error
diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
index 0f62eeae69ef8..f7a76079cfca2 100644
--- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
+++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
@@ -127,6 +127,10 @@ def on_fit_start(self):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
+
+    # Run validate before fit to check that async checkpointing works when fit or validate is called multiple times.
+    trainer.validate(model)
+
     trainer.fit(model)
     assert checkpoint_plugin.save_checkpoint.call_count == 3
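
Context for the fix, not part of the patch: `concurrent.futures` raises `RuntimeError: cannot schedule new futures after shutdown` when `submit` is called on an executor that has already been shut down. Below is a minimal, self-contained sketch of the old one-shot behaviour versus the lazy re-creation pattern the patch adopts; the class and method names (`OneShot`, `Lazy`, `save`) are illustrative only, not Lightning API.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


class OneShot:
    """Mimics the old behaviour: the executor is created exactly once in __init__."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)

    def save(self) -> None:
        self._executor.submit(print, "saving")

    def teardown(self) -> None:
        self._executor.shutdown(wait=True)


class Lazy:
    """Mimics the fix: the executor is (re-)created on first use after teardown."""

    def __init__(self) -> None:
        self._executor: Optional[ThreadPoolExecutor] = None

    def save(self) -> None:
        if self._executor is None:  # lazy (re-)creation, as in _ensure_setup
            self._executor = ThreadPoolExecutor(max_workers=1)
        self._executor.submit(print, "saving")

    def teardown(self) -> None:
        if self._executor is not None:
            self._executor.shutdown(wait=True)
            self._executor = None


one_shot = OneShot()
one_shot.save()
one_shot.teardown()
try:
    one_shot.save()  # raises RuntimeError: cannot schedule new futures after shutdown
except RuntimeError as err:
    print(f"old behaviour: {err}")

lazy = Lazy()
lazy.save()
lazy.teardown()
lazy.save()  # works: a fresh executor is created on demand
lazy.teardown()
```

Calling `_ensure_setup` from `save_checkpoint` keeps the fix local to the plugin: since the `CheckpointIO` interface defines no setup counterpart to `teardown`, lazy creation at the first save of each trainer stage is the least invasive alternative to changing the interface.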
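For completeness, a usage-level sketch of the sequence the new test exercises: validate followed by fit on one `Trainer` with `AsyncCheckpointIO` plugged in. The toy model, dataset, and trainer flags below are placeholder assumptions, not taken from the patch; only the `plugins=[AsyncCheckpointIO()]` wiring and the validate-then-fit order reflect the tested scenario.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO


class RandomDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset())

    def val_dataloader(self):
        return DataLoader(RandomDataset())


if __name__ == "__main__":
    model = ToyModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        enable_progress_bar=False,
        plugins=[AsyncCheckpointIO()],
    )
    trainer.validate(model)  # tears the plugin down at the end of the stage ...
    trainer.fit(model)       # ... which used to leave a dead executor behind
```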