
Commit 8c9aabe

Update documentation for supporting DCP (#261)

1 parent c5a44a2 commit 8c9aabe

File tree

2 files changed: +218 -0 lines changed


README.md

Lines changed: 57 additions & 0 deletions
@@ -114,6 +114,63 @@ For example, assuming the following directory bucket name `my-test-bucket--usw2-
usw2-az1, then the URI used will look like: `s3://my-test-bucket--usw2-az1--x-s3/<PREFIX>` (**please note that the
prefix for Amazon S3 Express One Zone should end with '/'**), paired with region us-west-2.

## Distributed checkpoints

### Overview

Amazon S3 Connector for PyTorch provides robust support for PyTorch distributed checkpoints. This feature includes:

- `S3StorageWriter` and `S3StorageReader`: Implementations of PyTorch's StorageWriter and StorageReader interfaces.
- `S3FileSystem`: An implementation of PyTorch's FileSystemBase.

These tools enable seamless integration of Amazon S3 with
[PyTorch Distributed Checkpoints](https://pytorch.org/docs/stable/distributed.checkpoint.html),
allowing efficient storage and retrieval of distributed model checkpoints.

### Prerequisites and Installation

PyTorch 2.3 or newer is required.
To use the distributed checkpoints feature, install S3 Connector for PyTorch with the `dcp` extra:

```sh
pip install s3torchconnector[dcp]
```
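
As a quick sanity check, the following one-liner should import the DCP integration and print a PyTorch version of 2.3 or newer (the `s3torchconnector.dcp` module path matches the example below):

```sh
# Imports should succeed and the printed version should be >= 2.3
python -c "import torch; import s3torchconnector.dcp; print(torch.__version__)"
```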

### Sample Example

End-to-end examples for using distributed checkpoints with S3 Connector for PyTorch
can be found in the [examples/dcp](examples/dcp) directory.

```py
from s3torchconnector.dcp import S3StorageWriter, S3StorageReader

import torchvision
import torch.distributed.checkpoint as DCP

# Configuration
CHECKPOINT_URI = "s3://<BUCKET>/<KEY>/"
REGION = "us-east-1"

model = torchvision.models.resnet18()

# Save distributed checkpoint to S3
s3_storage_writer = S3StorageWriter(region=REGION, path=CHECKPOINT_URI)
DCP.save(
    state_dict=model.state_dict(),
    storage_writer=s3_storage_writer,
)

# Load distributed checkpoint from S3
model = torchvision.models.resnet18()
model_state_dict = model.state_dict()
s3_storage_reader = S3StorageReader(region=REGION, path=CHECKPOINT_URI)
DCP.load(
    state_dict=model_state_dict,
    storage_reader=s3_storage_reader,
)
model.load_state_dict(model_state_dict)
```

## Parallel/Distributed Training

Amazon S3 Connector for PyTorch provides support for parallel and distributed training with PyTorch,

examples/dcp/stateful_example.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
```py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

# inspired by https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/examples/stateful_example.py

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from s3torchconnector.dcp import S3StorageWriter, S3StorageReader


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


def _make_stateful(model, optim):
    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)


def _train(model, optim, train_steps=1):
    torch.manual_seed(0)
    loss = None
    for _ in range(train_steps):
        loss = model(model.get_input()).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

    return loss


def _init_model(device, world_size):
    device_mesh = init_device_mesh(device, (world_size,))
    model = Model().cuda()
    model = FSDP(
        model,
        device_mesh=device_mesh,
        use_orig_params=True,
    )
    optim = torch.optim.Adam(model.parameters(), lr=0.1)
    _make_stateful(model, optim)

    return model, optim


def _compare_models(model1, model2, rank, rtol=1e-5, atol=1e-8):
    model1.eval()
    model2.eval()

    with FSDP.summon_full_params(model1), FSDP.summon_full_params(model2):
        for (name1, param1), (name2, param2) in zip(
            model1.named_parameters(), model2.named_parameters()
        ):
            if name1 != name2:
                print(f"Parameter names don't match: {name1} vs {name2}. Rank:{rank}")
                return False

            if not torch.allclose(param1, param2, rtol=rtol, atol=atol):
                print(f"Parameters don't match for {name1}. Rank:{rank}")
                print(
                    f"Max difference: {(param1 - param2).abs().max().item()}. Rank:{rank}"
                )
                return False

    print(f"All parameters match within the specified tolerance. Rank:{rank}")
    return True


def _setup(rank, world_size):
    # Set up world process group
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def _train_initial_model(device, rank, world_size):
    print(f"Train initial model on rank:{rank}")
    model, optim = _init_model(device, world_size)
    _train(model, optim, train_steps=2)
    return model, optim


def _train_model_to_different_state(device, model, rank, world_size):
    print(f"Train another model on rank:{rank}")
    loaded_model, loaded_optim = _init_model(device, world_size)
    _train(loaded_model, loaded_optim, train_steps=4)
    print(f"Check that models are different on rank:{rank}")
    assert not _compare_models(model, loaded_model, rank)
    return loaded_model, loaded_optim


def _continue_training_loaded_model(loaded_model, loaded_optim, model, rank):
    print(f"Check that loaded model and original model are the same on rank:{rank}")
    assert _compare_models(model, loaded_model, rank)
    print(f"Train loaded model on rank:{rank}")
    _train(loaded_model, loaded_optim, train_steps=2)


def run(rank, world_size, region, s3_uri, device="cuda"):
    _setup(rank, world_size)
    model, optim = _train_initial_model(device, rank, world_size)

    print(f"Saving checkpoint on rank:{rank}")
    # Initialize S3StorageWriter with the region and checkpoint URI, then pass it to dcp.save as the writer
    storage_writer = S3StorageWriter(region, s3_uri)
    dcp.save(
        state_dict={"model": model, "optimizer": optim},
        storage_writer=storage_writer,
    )

    # Presumably do something else, then decide to return to the previous version of the model
    modified_model, modified_optim = _train_model_to_different_state(
        device, model, rank, world_size
    )
    print(f"Load previously saved checkpoint on rank:{rank}")
    # Initialize S3StorageReader with the region and checkpoint URI, then pass it to dcp.load as the reader
    storage_reader = S3StorageReader(region, s3_uri)
    dcp.load(
        state_dict={"model": modified_model, "optimizer": modified_optim},
        storage_reader=storage_reader,
    )
    _continue_training_loaded_model(modified_model, modified_optim, model, rank)
    print(f"Quitting on rank:{rank}")


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    region = os.getenv("REGION")
    s3_uri = os.getenv("CHECKPOINT_PATH")
    print(f"Running stateful checkpoint example on {world_size} devices.")
    mp.spawn(
        run,
        args=(world_size, region, s3_uri),
        nprocs=world_size,
        join=True,
    )
```
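
The script reads its S3 destination from the `REGION` and `CHECKPOINT_PATH` environment variables (see the `os.getenv` calls above); a typical invocation on a multi-GPU host might look like the following sketch, where the bucket, prefix, and region are placeholder values:

```sh
# Placeholder bucket/prefix/region; mp.spawn launches one process per visible GPU
REGION=us-east-1 CHECKPOINT_PATH="s3://<BUCKET>/<PREFIX>/" \
    python examples/dcp/stateful_example.py
```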

0 commit comments
