
Commit dd3e717

teja-rao authored and pytorchmergebot committed
Add async checkpointing impl to experimental checkpointer and add a builder API (pytorch#156927)
1. Adds an AsyncCheckpointer with out-of-process checkpointing and a state_dict stager with shared-memory, pinned-memory, and zero-overhead support.
2. Adds two convenient functions to create sync/async checkpointers.

Differential Revision: [D77336833](https://our.internmc.facebook.com/intern/diff/D77336833/)
Pull Request resolved: pytorch#156927
Approved by: https://github.com/pradeepfn
1 parent 7081b82 commit dd3e717
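
For context, here is a minimal usage sketch of the new builder API, mirroring what the test below exercises. The module paths, factory names, and keyword arguments are taken directly from the test file; the single-rank RankInfo, barrier_type=None, and the /tmp path are illustrative assumptions, not recommendations:

import torch
from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.builder import make_sync_checkpointer
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.types import RankInfo

# Single-rank setup with no cross-rank barrier (simplifying assumption).
rank_info = RankInfo(global_world_size=1, global_rank=0)
config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
checkpointer = make_sync_checkpointer(config=config, rank_info=rank_info)

state_dict = {"model": torch.nn.Linear(10, 5).state_dict(), "epoch": 5}

# A sync save blocks until the checkpoint is written and returns None;
# load() reads it back from the same path.
checkpointer.save(state_dict, "/tmp/ckpt_sync_demo")
loaded = checkpointer.load("/tmp/ckpt_sync_demo")
assert loaded["epoch"] == 5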

File tree

12 files changed: +1890, -0 lines changed

Lines changed: 165 additions & 0 deletions
# Owner(s): ["oncall: distributed checkpointing"]

import os
import shutil
import tempfile

import torch
from torch.distributed.checkpoint._experimental.barriers import BarrierConfig
from torch.distributed.checkpoint._experimental.builder import (
    make_async_checkpointer,
    make_sync_checkpointer,
)
from torch.distributed.checkpoint._experimental.checkpointer import (
    AsyncCheckpointer,
    SyncCheckpointer,
)
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig
from torch.distributed.checkpoint._experimental.types import RankInfo
from torch.testing._internal.common_utils import run_tests, TestCase


class TestMakeCheckpointer(TestCase):
    def setUp(self) -> None:
        # Create a temporary directory for checkpoints
        self.temp_dir = tempfile.mkdtemp()

        # Create real objects for testing
        self.rank_info = RankInfo(
            global_world_size=1,
            global_rank=0,
        )

        # Create a test state dictionary
        self.state_dict = {
            "model": torch.nn.Linear(10, 5).state_dict(),
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
            "step": 1000,
        }

    def tearDown(self) -> None:
        # Clean up the temporary directory
        shutil.rmtree(self.temp_dir)

    def test_make_sync_checkpointer(self) -> None:
        """Test creating a synchronous checkpointer using make_sync_checkpointer."""
        # Create sync checkpointer using factory function with no barrier
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_sync")
        result = checkpointer.save(self.state_dict, checkpoint_path)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

        # Test loading
        loaded_state_dict = checkpointer.load(checkpoint_path)
        self.assertEqual(loaded_state_dict["epoch"], 5)

    def test_make_sync_checkpointer_with_config_first(self) -> None:
        """Test creating a synchronous checkpointer with config as first parameter."""
        # Create sync checkpointer with config as first parameter
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))
        checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(
            self.temp_dir, "checkpoint_factory_sync_config_first"
        )
        result = checkpointer.save(self.state_dict, checkpoint_path)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

    def test_make_sync_checkpointer_with_custom_config(self) -> None:
        """Test creating a synchronous checkpointer with a custom config."""
        # Create a custom config with no barrier
        config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None))

        # Create sync checkpointer with the custom config
        checkpointer = make_sync_checkpointer(rank_info=self.rank_info, config=config)

        # Verify it's a SyncCheckpointer instance
        self.assertIsInstance(checkpointer, SyncCheckpointer)

        # Test that it works for sync operations
        checkpoint_path = os.path.join(
            self.temp_dir, "checkpoint_factory_sync_custom_config"
        )
        result = checkpointer.save(self.state_dict, checkpoint_path)
        self.assertIsNone(result)  # Sync mode returns None

        # Verify checkpoint was created
        checkpoint_file = os.path.join(
            checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
        )
        self.assertTrue(os.path.exists(checkpoint_file))

        # Test loading
        loaded_state_dict = checkpointer.load(checkpoint_path)
        self.assertEqual(loaded_state_dict["epoch"], 5)

    def test_make_async_checkpointer(self) -> None:
        """Test creating an asynchronous checkpointer using make_async_checkpointer."""
        # Create async checkpointer using factory function with default parameters
        config: CheckpointerConfig = CheckpointerConfig()
        config.staging_config = CheckpointStagerConfig(
            use_cuda_non_blocking_copy=torch.cuda.is_available(),
            use_pinned_memory=torch.cuda.is_available(),
        )
        checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info)

        try:
            # Verify it's an AsyncCheckpointer instance
            self.assertIsInstance(checkpointer, AsyncCheckpointer)

            # Test that it works for async operations
            checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_async")
            stage_future, write_future = checkpointer.save(
                self.state_dict, checkpoint_path
            )

            # Verify futures are returned
            self.assertIsNotNone(stage_future)
            self.assertIsNotNone(write_future)

            # Wait for completion
            stage_future.result()
            write_future.result()

            # Verify checkpoint was created
            checkpoint_file = os.path.join(
                checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt"
            )
            self.assertTrue(os.path.exists(checkpoint_file))

            # Test loading
            loaded_state_dict = checkpointer.load(checkpoint_path)
            self.assertEqual(loaded_state_dict["epoch"], 5)

        finally:
            # Clean up
            checkpointer.close()


if __name__ == "__main__":
    run_tests()
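
As the async test shows, AsyncCheckpointer.save() returns a pair of futures: one for staging the state dict and one for the out-of-process write. A natural use is to overlap training with the write, sketched below under the assumption that staging completion means the training process may safely mutate its tensors again (the test itself only asserts that both futures resolve); the /tmp path and the "overlapped training" placeholder are illustrative:

import torch
from torch.distributed.checkpoint._experimental.builder import make_async_checkpointer
from torch.distributed.checkpoint._experimental.config import CheckpointerConfig
from torch.distributed.checkpoint._experimental.types import RankInfo

state_dict = {"model": torch.nn.Linear(10, 5).state_dict(), "step": 1000}
checkpointer = make_async_checkpointer(
    config=CheckpointerConfig(),
    rank_info=RankInfo(global_world_size=1, global_rank=0),
)
try:
    # save() returns (stage_future, write_future).
    stage_future, write_future = checkpointer.save(state_dict, "/tmp/ckpt_async_demo")
    stage_future.result()   # assumption: once staged, tensors may be reused
    # ... overlapped training steps would go here ...
    write_future.result()   # block only when the checkpoint must be durable
finally:
    checkpointer.close()    # release background checkpointing resources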
