Skip to content

Commit 6e70fee

Browse files
dnanuti, IsaevIlya
authored and committed
Update User-Agent for Lightning checkpoints (#165)
* Update User-Agent for Lightning checkpoints * Update CHANGELOG.md
1 parent 141aafc commit 6e70fee

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Bug Fixes / Improvements
66
* Fix deadlock when enabling CRT debug logs. Removed former experimental method _enable_debug_logging().
77
* Refactor User-Agent setup for extensibility.
8+
* Update lightning User-Agent prefix to `s3torchconnector/{__version__} (lightning; {lightning.__version__})`.
89

910
## v1.1.4 (February 26, 2024)
1011

s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33

44
from typing import Optional, Dict, Any
55

6+
import lightning
67
import torch
78

89
from lightning.pytorch.plugins.io import CheckpointIO
910

1011
from .._s3client import S3Client
1112
from .._s3dataset_common import parse_s3_uri
13+
from .._user_agent import UserAgent
1214

1315

1416
class S3LightningCheckpoint(CheckpointIO):
1517
"""A checkpoint manager for S3 using the :class:`CheckpointIO` interface."""
1618

1719
def __init__(self, region: str):
1820
self.region = region
19-
self._client = S3Client(region)
21+
user_agent = UserAgent(["lightning", lightning.__version__])
22+
self._client = S3Client(region, user_agent=user_agent)
2023

2124
def save_checkpoint(
2225
self,

s3torchconnector/tst/e2e/test_e2e_s3_lightning_checkpoint.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
3+
import lightning
34
import pytest
45
import random
56
import torch
@@ -15,7 +16,7 @@
1516
from s3torchconnector._s3client import S3Client
1617
from s3torchconnector._s3dataset_common import parse_s3_uri
1718
from s3torchconnector.lightning import S3LightningCheckpoint
18-
from s3torchconnectorclient import S3Exception
19+
from s3torchconnectorclient import S3Exception, __version__
1920

2021
from models.net import Net
2122
from models.lightning_transformer import LightningTransformer, L
@@ -24,6 +25,7 @@
2425
def test_save_and_load_checkpoint(checkpoint_directory):
2526
tensor = torch.rand(3, 10, 10)
2627
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
28+
_verify_user_agent(s3_lightning_checkpoint)
2729
checkpoint_name = "lightning_checkpoint.ckpt"
2830
s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_name}"
2931
s3_lightning_checkpoint.save_checkpoint(tensor, s3_uri)
@@ -38,17 +40,19 @@ def test_load_compatibility_with_s3_checkpoint(checkpoint_directory):
3840
s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_name}"
3941
with checkpoint.writer(s3_uri) as writer:
4042
torch.save(tensor, writer)
41-
lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
42-
loaded_checkpoint = lightning_checkpoint.load_checkpoint(s3_uri)
43+
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
44+
_verify_user_agent(s3_lightning_checkpoint)
45+
loaded_checkpoint = s3_lightning_checkpoint.load_checkpoint(s3_uri)
4346
assert torch.equal(tensor, loaded_checkpoint)
4447

4548

4649
def test_save_compatibility_with_s3_checkpoint(checkpoint_directory):
4750
tensor = torch.rand(3, 10, 10)
4851
checkpoint_name = "lightning_checkpoint.ckpt"
49-
lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
52+
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
53+
_verify_user_agent(s3_lightning_checkpoint)
5054
s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_name}"
51-
lightning_checkpoint.save_checkpoint(tensor, s3_uri)
55+
s3_lightning_checkpoint.save_checkpoint(tensor, s3_uri)
5256
checkpoint = S3Checkpoint(region=checkpoint_directory.region)
5357
loaded_checkpoint = torch.load(checkpoint.reader(s3_uri))
5458
assert torch.equal(tensor, loaded_checkpoint)
@@ -57,14 +61,15 @@ def test_save_compatibility_with_s3_checkpoint(checkpoint_directory):
5761
def test_delete_checkpoint(checkpoint_directory):
5862
tensor = torch.rand(3, 10, 10)
5963
checkpoint_name = "lightning_checkpoint.ckpt"
60-
lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
64+
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
65+
_verify_user_agent(s3_lightning_checkpoint)
6166
s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_name}"
62-
lightning_checkpoint.save_checkpoint(tensor, s3_uri)
63-
loaded_checkpoint = lightning_checkpoint.load_checkpoint(s3_uri)
67+
s3_lightning_checkpoint.save_checkpoint(tensor, s3_uri)
68+
loaded_checkpoint = s3_lightning_checkpoint.load_checkpoint(s3_uri)
6469
assert torch.equal(tensor, loaded_checkpoint)
65-
lightning_checkpoint.remove_checkpoint(s3_uri)
70+
s3_lightning_checkpoint.remove_checkpoint(s3_uri)
6671
with pytest.raises(S3Exception, match="Service error: The key does not exist"):
67-
lightning_checkpoint.load_checkpoint(s3_uri)
72+
s3_lightning_checkpoint.load_checkpoint(s3_uri)
6873

6974

7075
def test_load_trained_checkpoint(checkpoint_directory):
@@ -78,6 +83,7 @@ def test_load_trained_checkpoint(checkpoint_directory):
7883
s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_name}"
7984
trainer.save_checkpoint(s3_uri)
8085
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
86+
_verify_user_agent(s3_lightning_checkpoint)
8187
loaded_checkpoint = s3_lightning_checkpoint.load_checkpoint(s3_uri)
8288
_verify_equal_state_dict(model.state_dict(), loaded_checkpoint["state_dict"])
8389

@@ -88,6 +94,7 @@ def test_compatibility_with_trainer_plugins(checkpoint_directory):
8894
dataloader = DataLoader(dataset, num_workers=3)
8995
model = LightningTransformer(vocab_size=dataset.vocab_size)
9096
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
97+
_verify_user_agent(s3_lightning_checkpoint)
9198
trainer = L.Trainer(
9299
default_root_dir=checkpoint_directory.s3_uri,
93100
plugins=[s3_lightning_checkpoint],
@@ -113,6 +120,7 @@ def test_compatibility_with_checkpoint_callback(checkpoint_directory):
113120

114121
model = LightningTransformer(vocab_size=dataset.vocab_size)
115122
s3_lightning_checkpoint = S3LightningCheckpoint(checkpoint_directory.region)
123+
_verify_user_agent(s3_lightning_checkpoint)
116124

117125
checkpoint_callback = ModelCheckpoint(
118126
dirpath=checkpoint_directory.s3_uri,
@@ -140,6 +148,7 @@ def test_compatibility_with_checkpoint_callback(checkpoint_directory):
140148

141149
checkpoint_s3_uri = f"{checkpoint_directory.s3_uri}{expected_checkpoint_name}"
142150
loaded_checkpoint = s3_lightning_checkpoint.load_checkpoint(checkpoint_s3_uri)
151+
_verify_user_agent(s3_lightning_checkpoint)
143152
_verify_equal_state_dict(model.state_dict(), loaded_checkpoint["state_dict"])
144153

145154

@@ -150,6 +159,7 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
150159

151160
model = LightningTransformer(vocab_size=dataset.vocab_size)
152161
s3_lightning_checkpoint = S3LightningCheckpoint(checkpoint_directory.region)
162+
_verify_user_agent(s3_lightning_checkpoint)
153163
async_s3_lightning_checkpoint = AsyncCheckpointIO(s3_lightning_checkpoint)
154164

155165
trainer = L.Trainer(
@@ -168,6 +178,7 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
168178
checkpoint_key = "lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt"
169179
checkpoint_s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_key}"
170180
loaded_checkpoint = s3_lightning_checkpoint.load_checkpoint(checkpoint_s3_uri)
181+
_verify_user_agent(s3_lightning_checkpoint)
171182
_verify_equal_state_dict(model.state_dict(), loaded_checkpoint["state_dict"])
172183

173184

@@ -225,6 +236,7 @@ def test_nn_checkpointing(checkpoint_directory):
225236
# Assert that eval and train do not raise
226237
loaded_nn_model.eval()
227238
loaded_nn_model.train()
239+
_verify_user_agent(s3_lightning_checkpoint)
228240

229241

230242
def _verify_equal_state_dict(
@@ -236,3 +248,10 @@ def _verify_equal_state_dict(
236248
# These are tuples (str, Tensor)
237249
assert model_key == loaded_key
238250
assert torch.equal(model_value, loaded_value)
251+
252+
253+
def _verify_user_agent(s3_lightning_checkpoint: S3LightningCheckpoint):
254+
expected_user_agent = (
255+
f"s3torchconnector/{__version__} (lightning; {lightning.__version__})"
256+
)
257+
assert s3_lightning_checkpoint._client.user_agent_prefix == expected_user_agent

0 commit comments

Comments
 (0)