Skip to content

Commit 640a154

Browse files
authored
Added support for passing a custom endpoint to S3LightningCheckpoint (#201)
* Added support for passing a custom endpoint to S3LightningCheckpoint * Updated CHANGELOG
1 parent 463691b commit 640a154

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
### New features
44
* Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.
55
* Improve the performance of `s3reader` when utilized with `pytorch.load` by incorporating support for the `readinto` method.
6+
* Add support for passing an optional custom endpoint to S3LightningCheckpoint constructor method.
67

78

89
## v1.2.2 (March 22, 2024)

s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ def __init__(
2020
self,
2121
region: str,
2222
s3client_config: Optional[S3ClientConfig] = None,
23+
endpoint: Optional[str] = None,
2324
):
2425
self.region = region
2526
user_agent = UserAgent(["lightning", lightning.__version__])
2627
self._client = S3Client(
2728
region,
2829
user_agent=user_agent,
2930
s3client_config=s3client_config,
31+
endpoint=endpoint,
3032
)
3133

3234
def save_checkpoint(

s3torchconnector/tst/unit/lightning/test_s3_lightning_checkpoint.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TEST_BUCKET = "test-bucket"
2828
TEST_KEY = "test-key"
2929
TEST_REGION = "us-east-1"
30+
TEST_ENDPOINT = "https://s3.us-east-1.amazonaws.com"
3031

3132

3233
@pytest.fixture()
@@ -148,6 +149,11 @@ def test_teardown(lightning_checkpoint):
148149
# Assert no exception is thrown - implicit
149150

150151

152+
def test_lightning_checkpoint_creation_with_region_and_endpoint():
153+
checkpoint = S3LightningCheckpoint(TEST_REGION, endpoint=TEST_ENDPOINT)
154+
assert isinstance(checkpoint, S3LightningCheckpoint)
155+
156+
151157
def _test_save(
152158
client,
153159
checkpoint: CheckpointIO,

0 commit comments

Comments
 (0)