
Commit 141aafc

muddyfish and IsaevIlya authored and committed
Add examples for reading/writing checkpoints with lightning in various ways (#150)
* Add examples for reading/writing checkpoints with lightning in various ways
* Add small set of documentation for lightning support to the README

---------

Co-authored-by: Simon Beal <[email protected]>
1 parent 860a9d8 commit 141aafc

File tree

5 files changed: +190 −0 lines changed

README.md

Lines changed: 34 additions & 0 deletions
@@ -115,6 +115,40 @@ For example, assuming the following directory bucket name `my-test-bucket--usw2-

usw2-az1, then the URI used will look like: `s3://my-test-bucket--usw2-az1--x-s3/<PREFIX>` (**please note that the prefix for Amazon S3 Express One Zone should end with '/'**), paired with region us-west-2.

## Lightning Integration

Amazon S3 Connector for PyTorch includes an integration for PyTorch Lightning, featuring `S3LightningCheckpoint`, an implementation of Lightning's `CheckpointIO` interface. This allows users to use Amazon S3 Connector for PyTorch's S3 checkpointing functionality with PyTorch Lightning.

### Getting Started

#### Installation

```sh
pip install s3torchconnector[lightning]
```

### Examples

End-to-end examples for the PyTorch Lightning integration can be found in the [examples/lightning](https://github.com/awslabs/s3-connector-for-pytorch/tree/main/examples/lightning) directory.

```py
from lightning import Trainer
from s3torchconnector.lightning import S3LightningCheckpoint

...

s3_checkpoint_io = S3LightningCheckpoint("us-east-1")
trainer = Trainer(
    plugins=[s3_checkpoint_io],
    default_root_dir="s3://bucket_name/key_prefix/",
)
trainer.fit(model)
```

## Contributing

We welcome contributions to Amazon S3 Connector for PyTorch. Please see [CONTRIBUTING](https://github.com/awslabs/s3-connector-for-pytorch/blob/main/doc/CONTRIBUTING.md)
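
Once the plugin is configured, manual saves also route through it, as the manual checkpoint-writing example later in this commit demonstrates. A minimal sketch, assuming the trained `trainer` from the snippet above; the bucket name and key are placeholders:

```py
# Trainer.save_checkpoint goes through the configured CheckpointIO plugin,
# so an S3 URI works directly here. Bucket and key are placeholders.
trainer.save_checkpoint("s3://bucket_name/key_prefix/final.ckpt")
```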
Lines changed: 44 additions & 0 deletions
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import WikiText2, LightningTransformer
from lightning.pytorch.plugins import AsyncCheckpointIO
from torch.utils.data import DataLoader

from s3torchconnector.lightning import S3LightningCheckpoint


def main(region: str, checkpoint_path: str):
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=3)

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    s3_lightning_checkpoint = S3LightningCheckpoint(region)
    async_checkpoint = AsyncCheckpointIO(s3_lightning_checkpoint)

    # This creates one checkpoint per step; `max_steps=8` below caps training at 8 steps.
    # To checkpoint more or less often, change `every_n_train_steps`.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_top_k=-1,
        every_n_train_steps=1,
        filename="checkpoint-{epoch:02d}-{step:02d}",
        enable_version_counter=True,
    )

    trainer = Trainer(
        plugins=[async_checkpoint],
        callbacks=[checkpoint_callback],
        min_epochs=4,
        max_epochs=8,
        max_steps=8,
    )
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    import os

    main(os.getenv("REGION"), os.getenv("CHECKPOINT_PATH"))
Lines changed: 34 additions & 0 deletions
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from lightning import Trainer
from lightning.pytorch.demos import WikiText2, LightningTransformer
from torch.utils.data import DataLoader

from s3torchconnector.lightning import S3LightningCheckpoint


def main(region: str, checkpoint_path: str):
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=3)

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    s3_lightning_checkpoint = S3LightningCheckpoint(region)

    # No automatic checkpointing is set up here.
    trainer = Trainer(
        plugins=[s3_lightning_checkpoint],
        enable_checkpointing=False,
        min_epochs=4,
        max_epochs=5,
        max_steps=3,
    )
    trainer.fit(model, dataloader)
    # Manually save a checkpoint to the desired location.
    trainer.save_checkpoint(checkpoint_path)


if __name__ == "__main__":
    import os

    main(os.getenv("REGION"), os.getenv("CHECKPOINT_PATH"))
Lines changed: 31 additions & 0 deletions
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from lightning import Trainer
from lightning.pytorch.demos import WikiText2, LightningTransformer
from torch.utils.data import DataLoader

from s3torchconnector.lightning import S3LightningCheckpoint


def main(region: str, checkpoint_path: str):
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=3)

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    s3_lightning_checkpoint = S3LightningCheckpoint(region)

    trainer = Trainer(
        plugins=[s3_lightning_checkpoint],
        min_epochs=4,
        max_epochs=5,
        max_steps=3,
    )
    # Load the checkpoint at `ckpt_path` before training
    trainer.fit(model, dataloader, ckpt_path=checkpoint_path)


if __name__ == "__main__":
    import os

    main(os.getenv("REGION"), os.getenv("CHECKPOINT_PATH"))
Lines changed: 47 additions & 0 deletions
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import WikiText2, LightningTransformer
from torch.utils.data import DataLoader

from s3torchconnector.lightning import S3LightningCheckpoint


def main(region: str, checkpoint_path: str, save_only_latest: bool):
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=3)

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    s3_lightning_checkpoint = S3LightningCheckpoint(region)

    # Save once per step, and if `save_only_latest`, replace the last checkpoint each time.
    # Replacing is implemented by saving the new checkpoint, and then deleting the previous one.
    # If `save_only_latest` is False, a new checkpoint is created for each step.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_top_k=1 if save_only_latest else -1,
        every_n_train_steps=1,
        filename="checkpoint-{epoch:02d}-{step:02d}",
        enable_version_counter=True,
    )

    trainer = Trainer(
        plugins=[s3_lightning_checkpoint],
        callbacks=[checkpoint_callback],
        min_epochs=4,
        max_epochs=5,
        max_steps=3,
    )
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    import os

    main(
        os.getenv("REGION"),
        os.getenv("CHECKPOINT_PATH"),
        os.getenv("LATEST_CHECKPOINT_ONLY") == "1",
    )
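
The deletion described in the comment above goes through `remove_checkpoint`, which is part of Lightning's `CheckpointIO` contract and therefore implemented by `S3LightningCheckpoint`. A minimal sketch, assuming `s3_lightning_checkpoint` from the script above; the URI is a placeholder:

```py
# Hedged sketch: delete a superseded checkpoint object from S3.
# remove_checkpoint(path) is part of the CheckpointIO interface;
# the URI below is a placeholder.
s3_lightning_checkpoint.remove_checkpoint(
    "s3://my-bucket/checkpoints/checkpoint-epoch=00-step=01.ckpt"
)
```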
