Conversation

@littlebullGit (Contributor) commented on Aug 17, 2025

Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation

Fixes #20953

Summary

  • Root cause: Background thread serialized live tensor references; the training thread mutated tensors after scheduling the async save, leading to mixed-step checkpoints.
  • Fix: Snapshot all tensors on the main thread before submitting the async save using apply_to_collection(..., torch.Tensor, lambda t: t.detach().clone()); see the sketch after this list.
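
A minimal sketch of the snapshot step, assuming an executor-based wrapper around a synchronous checkpoint IO; the class name `_AsyncSaverSketch` and the `_base_io`/`_executor`/`_future` attributes are illustrative, not the exact Lightning source:

```python
# Minimal sketch (assumed names, not the Lightning implementation): clone all
# tensors on the caller thread, then hand the snapshot to a background saver.
from concurrent.futures import ThreadPoolExecutor

import torch
from lightning_utilities.core.apply_func import apply_to_collection


class _AsyncSaverSketch:
    def __init__(self, base_io):
        self._base_io = base_io  # wrapped, synchronous checkpoint IO
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._future = None

    def save_checkpoint(self, *args, **kwargs):
        # Point-in-time snapshot: clone every tensor found in positional and
        # keyword arguments (nested dicts/lists included), so in-place
        # parameter updates after this call cannot leak into the saved file.
        args = apply_to_collection(args, torch.Tensor, lambda t: t.detach().clone())
        kwargs = apply_to_collection(kwargs, torch.Tensor, lambda t: t.detach().clone())
        self._future = self._executor.submit(self._base_io.save_checkpoint, *args, **kwargs)

    def teardown(self):
        self._executor.shutdown(wait=True)
        if self._future is not None:
            self._future.result()  # re-raises any exception from the background save
```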

Implementation

  • Reproduces the issue in a unit test (a hedged reproduction sketch follows this list).
  • Clones all tensors in the checkpoint payload on the caller thread to take a point-in-time snapshot.
  • Supports both positional and keyword checkpoint parameters.
  • Preserves non-tensor values; handles nested containers.
  • Continues to surface background exceptions on teardown.
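
A hedged, self-contained reproduction sketch (not the repository's test): it reuses the `_AsyncSaverSketch` class from above together with a deliberately slow recording saver, both of which are assumptions for illustration.

```python
# Reproduction sketch: the slow base saver lags behind the training thread,
# which mutates the tensor in place right after scheduling the save. With the
# snapshot taken on the caller thread, the saved value is the pre-mutation one.
import time

import torch


class _SlowRecordingIO:
    def __init__(self):
        self.saved = None

    def save_checkpoint(self, checkpoint, path=None):
        time.sleep(0.2)  # simulate slow serialization in the background thread
        self.saved = checkpoint


def test_async_save_is_point_in_time():
    base = _SlowRecordingIO()
    async_io = _AsyncSaverSketch(base)  # sketch class from the Summary above
    weight = torch.zeros(3)

    async_io.save_checkpoint({"state_dict": {"weight": weight}}, "ckpt.pt")
    weight.add_(1.0)  # in-place update, as an optimizer step would do
    async_io.teardown()  # wait for the background save and surface errors

    assert torch.equal(base.saved["state_dict"]["weight"], torch.zeros(3))
```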

📚 Documentation preview 📚: https://pytorch-lightning--21079.org.readthedocs.build/en/21079/

github-actions bot added the pl label (Generic label for PyTorch Lightning package) on Aug 17, 2025
Borda merged commit 2c74bee into Lightning-AI:master on Aug 18, 2025 (84 of 85 checks passed)
Borda added a commit that referenced this pull request on Aug 28, 2025: the fix above, cherry picked from commit 2c74bee (co-authored-by: Jirka B <[email protected]>).
lantiga pushed a commit that referenced this pull request on Aug 29, 2025: the same fix, cherry picked from commit 2c74bee (co-authored-by: Jirka B <[email protected]>).

Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

async checkpointing unsafe?

2 participants