diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index db85bcd1adfaf..107c389dba590 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `AsyncCheckpointIO` to snapshot tensors before saving asynchronously, avoiding a race with parameter mutation ([#21079](https://github.com/Lightning-AI/pytorch-lightning/pull/21079))
 
 - Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
 
diff --git a/src/lightning/pytorch/plugins/io/async_plugin.py b/src/lightning/pytorch/plugins/io/async_plugin.py
index 67c02189c541e..d174a3d0ed1ea 100644
--- a/src/lightning/pytorch/plugins/io/async_plugin.py
+++ b/src/lightning/pytorch/plugins/io/async_plugin.py
@@ -15,6 +15,8 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Optional
 
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override
 
 from lightning.fabric.plugins import CheckpointIO
@@ -41,6 +43,17 @@ def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
     def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
         """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
+        # Snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation.
+        def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
+            # Detach to drop autograd history, then clone to take a point-in-time copy.
+            return t.detach().clone()
+
+        # Rebuild args/kwargs with a cloned checkpoint (supports positional or keyword form).
+        if "checkpoint" in kwargs:
+            kwargs = {**kwargs, "checkpoint": apply_to_collection(kwargs["checkpoint"], torch.Tensor, _clone_tensor)}
+        elif len(args) >= 1:
+            args = (apply_to_collection(args[0], torch.Tensor, _clone_tensor), *args[1:])
+
         def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
             try:
                 assert self.checkpoint_io is not None
diff --git a/tests/tests_pytorch/plugins/test_async_checkpoint.py b/tests/tests_pytorch/plugins/test_async_checkpoint.py
new file mode 100644
index 0000000000000..0718dab78d75f
--- /dev/null
+++ b/tests/tests_pytorch/plugins/test_async_checkpoint.py
@@ -0,0 +1,53 @@
+import time
+from typing import Any, Optional
+
+import pytest
+import torch
+
+from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
+from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+
+
+class _CaptureCheckpointIO(CheckpointIO):
+    def __init__(self) -> None:
+        self.saved: Optional[dict[str, Any]] = None
+
+    def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
+        # Simulate some delay to widen the race window
+        time.sleep(0.05)
+        # Store the received checkpoint object (not a deep copy) to inspect tensor values
+        self.saved = checkpoint
+
+    def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def remove_checkpoint(self, path: str) -> None:
+        pass
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+def test_async_checkpoint_should_snapshot_values_before_mutation():
+    base = _CaptureCheckpointIO()
+    async_io = AsyncCheckpointIO(checkpoint_io=base)
+
+    # a tensor that we will mutate after scheduling the save
+    t = torch.tensor([0.0])
+    ckpt = {"w": t}
+
+    # schedule the async save
+    async_io.save_checkpoint(ckpt, path="unused")
+
+    # mutate immediately afterward to mimic the training thread stepping params
+    t.add_(1.0)
+
+    # ensure the background thread finished
+    async_io.teardown()
+
+    assert base.saved is not None, "Async save did not run"
+
+    # EXPECTATION: AsyncCheckpointIO should have captured value 0.0 (pre-mutation)
+    # BEHAVIOR WITHOUT THE FIX: it captures 1.0 because the dict holds references
+    assert torch.allclose(base.saved["w"], torch.tensor([0.0])), (
+        "AsyncCheckpointIO must snapshot the checkpoint (clone tensors) on the main thread "
+        "to avoid races with parameter mutation; got mutated value instead"
+    )
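
For context, a minimal usage sketch of the code path this patch hardens (not part of the diff): `AsyncCheckpointIO` is enabled by passing it to the `Trainer` via `plugins=`, which is the documented way to exercise the asynchronous save. The model and `default_root_dir` below are illustrative.

```python
# Sketch: wiring AsyncCheckpointIO into a Trainer (model and path are illustrative).
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO

# When no `checkpoint_io` is passed, AsyncCheckpointIO wraps the strategy's default
# checkpoint IO. With this patch, each save_checkpoint call clones the checkpoint's
# tensors on the caller thread before the worker thread writes them, so subsequent
# optimizer steps cannot mutate the state being saved.
trainer = Trainer(
    max_epochs=1,
    plugins=[AsyncCheckpointIO()],
    default_root_dir="lightning_logs",  # illustrative
)
trainer.fit(BoringModel())
```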