13 changes: 13 additions & 0 deletions src/lightning/pytorch/plugins/io/async_plugin.py
@@ -15,6 +15,8 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from typing_extensions import override

from lightning.fabric.plugins import CheckpointIO
@@ -41,6 +43,17 @@ def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
"""Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""

# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
# detach to avoid autograd history and clone to take a point-in-time copy
return t.detach().clone()

# rebuild args/kwargs with a cloned checkpoint (supports positional or kw form)
if "checkpoint" in kwargs:
kwargs = {**kwargs, "checkpoint": apply_to_collection(kwargs["checkpoint"], torch.Tensor, _clone_tensor)}
elif len(args) >= 1:
args = (apply_to_collection(args[0], torch.Tensor, _clone_tensor), *args[1:])

def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
try:
assert self.checkpoint_io is not None
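The core of the change is the point-in-time copy taken on the caller thread before the save is handed to the executor. A minimal standalone sketch of the same idea, assuming only `torch` and `lightning_utilities` are available (the `snapshot_checkpoint` helper is illustrative and not part of the plugin):

```python
import torch
from lightning_utilities.core.apply_func import apply_to_collection


def snapshot_checkpoint(checkpoint):
    # Replace every torch.Tensor leaf with a detached clone; non-tensor values
    # (ints, strings, the nested dict structure, ...) pass through unchanged.
    return apply_to_collection(checkpoint, torch.Tensor, lambda t: t.detach().clone())


ckpt = {"state_dict": {"w": torch.tensor([0.0])}, "epoch": 3}
snap = snapshot_checkpoint(ckpt)

ckpt["state_dict"]["w"].add_(1.0)  # the training loop keeps mutating parameters in place
print(snap["state_dict"]["w"])     # tensor([0.]) -- the snapshot is unaffected
```

Because the clone happens before the work is submitted to the `ThreadPoolExecutor`, the background save only ever sees the frozen copy, regardless of when the worker thread actually runs.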
53 changes: 53 additions & 0 deletions tests/tests_pytorch/plugins/test_async_checkpoint.py
@@ -0,0 +1,53 @@
import time
from typing import Any, Optional

import pytest
import torch

from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO


class _CaptureCheckpointIO(CheckpointIO):
def __init__(self) -> None:
self.saved: Optional[dict[str, Any]] = None

def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
# Simulate some delay to increase race window
time.sleep(0.05)
# Store the received checkpoint object (not a deep copy) to inspect tensor values
self.saved = checkpoint

def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> dict[str, Any]:
raise NotImplementedError

def remove_checkpoint(self, path: str) -> None:
pass


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_async_checkpoint_should_snapshot_values_before_mutation():
base = _CaptureCheckpointIO()
async_io = AsyncCheckpointIO(checkpoint_io=base)

# a tensor that we will mutate after scheduling the save
t = torch.tensor([0.0])
ckpt = {"w": t}

# schedule async save
async_io.save_checkpoint(ckpt, path="unused")

# mutate immediately afterward to mimic training thread stepping params
t.add_(1.0)

# ensure background thread finished
async_io.teardown()

assert base.saved is not None, "Async save did not run"

# EXPECTATION: AsyncCheckpointIO should have captured value 0.0 (pre-mutation)
# CURRENT BEHAVIOR (bug): it captures 1.0 because the dict holds references
assert torch.allclose(base.saved["w"], torch.tensor([0.0])), (
"AsyncCheckpointIO must snapshot the checkpoint (clone tensors) on the main thread "
"to avoid races with parameter mutation; got mutated value instead"
)
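For context, the race that the assertion above guards against can be reproduced without Lightning at all; the sketch below (hypothetical, not part of the test suite) hands a dict holding a live tensor to a background thread and mutates it before the slow save completes:

```python
import time
from concurrent.futures import ThreadPoolExecutor

import torch

saved = {}


def slow_save(checkpoint):
    time.sleep(0.05)              # widen the race window, like _CaptureCheckpointIO above
    saved["w"] = checkpoint["w"]  # keeps a reference to the caller's tensor, not a copy


t = torch.tensor([0.0])
with ThreadPoolExecutor(max_workers=1) as pool:
    fut = pool.submit(slow_save, {"w": t})
    t.add_(1.0)                   # "training" mutates the parameter in place
    fut.result()

print(saved["w"])  # tensor([1.]) -- without a caller-side clone, the mutation leaks into the save
```

Cloning the tensors on the caller thread, as the plugin change above does, is what turns this into the expected `tensor([0.])`.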