4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -31,8 +31,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AsyncCheckpointIO` to snapshot tensors to avoid a race with parameter mutation ([#21079](https://github.com/Lightning-AI/pytorch-lightning/pull/21079))
 
 
+- Fixed `AsyncCheckpointIO` thread-pool exception when calling `fit` or `validate` more than once ([#20952](https://github.com/Lightning-AI/pytorch-lightning/pull/20952))
+
+
 - Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
 
 
 ---
 
 ## [2.5.3] - 2025-08-13
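
For context on the #20952 entry, here is a minimal sketch of the scenario it fixes: reusing one `AsyncCheckpointIO` instance across `validate` and `fit` in the same process. It assumes the public `Trainer` plugin API and the `BoringModel` demo class; it is not part of this diff.

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.io import AsyncCheckpointIO

model = BoringModel()
trainer = Trainer(max_epochs=1, plugins=[AsyncCheckpointIO()], logger=False)

trainer.validate(model)  # first entry point: the plugin is set up and then torn down
trainer.fit(model)       # second entry point: used to raise from the already shut-down thread pool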
40 changes: 31 additions & 9 deletions src/lightning/pytorch/plugins/io/async_plugin.py
@@ -13,15 +13,17 @@
 # limitations under the License.
 
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override
 
-from lightning.fabric.plugins import CheckpointIO
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 
+if TYPE_CHECKING:
+    from lightning.fabric.plugins import CheckpointIO
+
 
 class AsyncCheckpointIO(_WrappingCheckpointIO):
     """``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
@@ -33,20 +35,30 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):
 
     """
 
+    _executor: Optional[ThreadPoolExecutor]
+    _error: Optional[BaseException]
+
     def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
         super().__init__(checkpoint_io)
+        self._executor = None
+        self._error = None
+
+    # CheckpointIO doesn't have a setup method, so we have to do the setup lazily here.
+    def _ensure_setup(self) -> None:
+        """Ensures that the executor is set up.
 
-        self._executor = ThreadPoolExecutor(max_workers=1)
-        self._error: Optional[BaseException] = None
+        We can't do the setup in ``__init__`` because, if ``fit`` or ``validate`` is called more than once, the
+        teardown method deletes the executor.
+
+        """
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=1)
 
     @override
     def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
         """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
 
-        # snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
-        def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
-            # detach to avoid autograd history and clone to take a point-in-time copy
-            return t.detach().clone()
+        self._ensure_setup()
 
         # rebuild args/kwargs with a cloned checkpoint (supports positional or kw form)
         if "checkpoint" in kwargs:
Expand All @@ -61,6 +73,7 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
except BaseException as ex:
self._error = ex

assert self._executor is not None
self._executor.submit(_save_checkpoint, *args, **kwargs)

# if an error was raised between the previous time `save_checkpoint`` was called and now,
@@ -71,8 +84,17 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
     @override
     def teardown(self) -> None:
         """This method is called to close the threads."""
-        self._executor.shutdown(wait=True)
+        if self._executor is not None:
+            self._executor.shutdown(wait=True)
+            self._executor = None
 
         # if an error was raised anytime in any of the `executor.submit` calls
         if self._error:
             raise self._error
+
+
+# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
+def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
+    """Clones a tensor on the caller thread."""
+    # detach to avoid autograd history and clone to take a point-in-time copy
+    return t.detach().clone()
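
To connect the module-level `_clone_tensor` helper above with the collapsed part of `save_checkpoint`: the checkpoint is rebuilt with `apply_to_collection` so every tensor is copied on the caller thread before the background save runs. A standalone sketch of that snapshotting idea, assuming a plain dict-of-tensors checkpoint (the variable names here are illustrative only):

import torch
from lightning_utilities.core.apply_func import apply_to_collection


def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    # detach to drop autograd history, clone to take a point-in-time copy
    return t.detach().clone()


checkpoint = {"state_dict": {"layer.weight": torch.randn(2, 2)}, "epoch": 3}

# snapshot taken on the caller thread: later in-place updates to the live parameters
# no longer affect what the worker thread eventually writes to disk
snapshot = apply_to_collection(checkpoint, dtype=torch.Tensor, function=_clone_tensor)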
4 changes: 4 additions & 0 deletions tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
@@ -127,6 +127,10 @@ def on_fit_start(self):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
+
+    # We add a validate step to test that async checkpointing works when fit or validate is called multiple times.
+    trainer.validate(model)
+
     trainer.fit(model)
 
     assert checkpoint_plugin.save_checkpoint.call_count == 3
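
The added `trainer.validate(model)` call makes the test run two entry points, and therefore two plugin setup/teardown cycles, in one process. The underlying constraint is that a `ThreadPoolExecutor` refuses new work after `shutdown`, which is why `teardown` now clears the executor and `_ensure_setup` re-creates it lazily. A minimal standard-library sketch of that pattern, independent of Lightning:

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)
executor.submit(print, "first entry point")
executor.shutdown(wait=True)

try:
    executor.submit(print, "second entry point")
except RuntimeError:
    # "cannot schedule new futures after shutdown": the old executor is unusable,
    # so drop it and create a fresh one on the next use, as _ensure_setup does
    executor = None

if executor is None:
    executor = ThreadPoolExecutor(max_workers=1)
executor.submit(print, "second entry point")  # works again
executor.shutdown(wait=True)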