11# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22# // SPDX-License-Identifier: BSD
3+ import lightning
34import pytest
45import random
56import torch
1516from s3torchconnector ._s3client import S3Client
1617from s3torchconnector ._s3dataset_common import parse_s3_uri
1718from s3torchconnector .lightning import S3LightningCheckpoint
18- from s3torchconnectorclient import S3Exception
19+ from s3torchconnectorclient import S3Exception , __version__
1920
2021from models .net import Net
2122from models .lightning_transformer import LightningTransformer , L
2425def test_save_and_load_checkpoint (checkpoint_directory ):
2526 tensor = torch .rand (3 , 10 , 10 )
2627 s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
28+ _verify_user_agent (s3_lightning_checkpoint )
2729 checkpoint_name = "lightning_checkpoint.ckpt"
2830 s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_name } "
2931 s3_lightning_checkpoint .save_checkpoint (tensor , s3_uri )
@@ -38,17 +40,19 @@ def test_load_compatibility_with_s3_checkpoint(checkpoint_directory):
3840 s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_name } "
3941 with checkpoint .writer (s3_uri ) as writer :
4042 torch .save (tensor , writer )
41- lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
42- loaded_checkpoint = lightning_checkpoint .load_checkpoint (s3_uri )
43+ s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
44+ _verify_user_agent (s3_lightning_checkpoint )
45+ loaded_checkpoint = s3_lightning_checkpoint .load_checkpoint (s3_uri )
4346 assert torch .equal (tensor , loaded_checkpoint )
4447
4548
4649def test_save_compatibility_with_s3_checkpoint (checkpoint_directory ):
4750 tensor = torch .rand (3 , 10 , 10 )
4851 checkpoint_name = "lightning_checkpoint.ckpt"
49- lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
52+ s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
53+ _verify_user_agent (s3_lightning_checkpoint )
5054 s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_name } "
51- lightning_checkpoint .save_checkpoint (tensor , s3_uri )
55+ s3_lightning_checkpoint .save_checkpoint (tensor , s3_uri )
5256 checkpoint = S3Checkpoint (region = checkpoint_directory .region )
5357 loaded_checkpoint = torch .load (checkpoint .reader (s3_uri ))
5458 assert torch .equal (tensor , loaded_checkpoint )
@@ -57,14 +61,15 @@ def test_save_compatibility_with_s3_checkpoint(checkpoint_directory):
5761def test_delete_checkpoint (checkpoint_directory ):
5862 tensor = torch .rand (3 , 10 , 10 )
5963 checkpoint_name = "lightning_checkpoint.ckpt"
60- lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
64+ s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
65+ _verify_user_agent (s3_lightning_checkpoint )
6166 s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_name } "
62- lightning_checkpoint .save_checkpoint (tensor , s3_uri )
63- loaded_checkpoint = lightning_checkpoint .load_checkpoint (s3_uri )
67+ s3_lightning_checkpoint .save_checkpoint (tensor , s3_uri )
68+ loaded_checkpoint = s3_lightning_checkpoint .load_checkpoint (s3_uri )
6469 assert torch .equal (tensor , loaded_checkpoint )
65- lightning_checkpoint .remove_checkpoint (s3_uri )
70+ s3_lightning_checkpoint .remove_checkpoint (s3_uri )
6671 with pytest .raises (S3Exception , match = "Service error: The key does not exist" ):
67- lightning_checkpoint .load_checkpoint (s3_uri )
72+ s3_lightning_checkpoint .load_checkpoint (s3_uri )
6873
6974
7075def test_load_trained_checkpoint (checkpoint_directory ):
@@ -78,6 +83,7 @@ def test_load_trained_checkpoint(checkpoint_directory):
7883 s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_name } "
7984 trainer .save_checkpoint (s3_uri )
8085 s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
86+ _verify_user_agent (s3_lightning_checkpoint )
8187 loaded_checkpoint = s3_lightning_checkpoint .load_checkpoint (s3_uri )
8288 _verify_equal_state_dict (model .state_dict (), loaded_checkpoint ["state_dict" ])
8389
@@ -88,6 +94,7 @@ def test_compatibility_with_trainer_plugins(checkpoint_directory):
8894 dataloader = DataLoader (dataset , num_workers = 3 )
8995 model = LightningTransformer (vocab_size = dataset .vocab_size )
9096 s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
97+ _verify_user_agent (s3_lightning_checkpoint )
9198 trainer = L .Trainer (
9299 default_root_dir = checkpoint_directory .s3_uri ,
93100 plugins = [s3_lightning_checkpoint ],
@@ -113,6 +120,7 @@ def test_compatibility_with_checkpoint_callback(checkpoint_directory):
113120
114121 model = LightningTransformer (vocab_size = dataset .vocab_size )
115122 s3_lightning_checkpoint = S3LightningCheckpoint (checkpoint_directory .region )
123+ _verify_user_agent (s3_lightning_checkpoint )
116124
117125 checkpoint_callback = ModelCheckpoint (
118126 dirpath = checkpoint_directory .s3_uri ,
@@ -140,6 +148,7 @@ def test_compatibility_with_checkpoint_callback(checkpoint_directory):
140148
141149 checkpoint_s3_uri = f"{ checkpoint_directory .s3_uri } { expected_checkpoint_name } "
142150 loaded_checkpoint = s3_lightning_checkpoint .load_checkpoint (checkpoint_s3_uri )
151+ _verify_user_agent (s3_lightning_checkpoint )
143152 _verify_equal_state_dict (model .state_dict (), loaded_checkpoint ["state_dict" ])
144153
145154
@@ -150,6 +159,7 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
150159
151160 model = LightningTransformer (vocab_size = dataset .vocab_size )
152161 s3_lightning_checkpoint = S3LightningCheckpoint (checkpoint_directory .region )
162+ _verify_user_agent (s3_lightning_checkpoint )
153163 async_s3_lightning_checkpoint = AsyncCheckpointIO (s3_lightning_checkpoint )
154164
155165 trainer = L .Trainer (
@@ -168,6 +178,7 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
168178 checkpoint_key = "lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt"
169179 checkpoint_s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_key } "
170180 loaded_checkpoint = s3_lightning_checkpoint .load_checkpoint (checkpoint_s3_uri )
181+ _verify_user_agent (s3_lightning_checkpoint )
171182 _verify_equal_state_dict (model .state_dict (), loaded_checkpoint ["state_dict" ])
172183
173184
@@ -225,6 +236,7 @@ def test_nn_checkpointing(checkpoint_directory):
225236 # Assert that eval and train do not raise
226237 loaded_nn_model .eval ()
227238 loaded_nn_model .train ()
239+ _verify_user_agent (s3_lightning_checkpoint )
228240
229241
230242def _verify_equal_state_dict (
@@ -236,3 +248,10 @@ def _verify_equal_state_dict(
236248 # These are tuples (str, Tensor)
237249 assert model_key == loaded_key
238250 assert torch .equal (model_value , loaded_value )
251+
252+
253+ def _verify_user_agent (s3_lightning_checkpoint : S3LightningCheckpoint ):
254+ expected_user_agent = (
255+ f"s3torchconnector/{ __version__ } (lightning; { lightning .__version__ } )"
256+ )
257+ assert s3_lightning_checkpoint ._client .user_agent_prefix == expected_user_agent
0 commit comments