Skip to content

Commit 7413e1b

Browse files
authored
Feature/handle pytorch load changes (#306)
* Introduce default parameter weights_only to load_checkpoint to be aligned with changes in PyTorch Lightning * Align test types with torch.load weights_only support
1 parent 0192628 commit 7413e1b

File tree

5 files changed

+12
-24
lines changed

5 files changed

+12
-24
lines changed

s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ def load_checkpoint(
7373
s3reader = self._client.get_object(bucket, key)
7474
# FIXME - io.BufferedIOBase and typing.IO aren't compatible
7575
# See https://github.com/python/typeshed/issues/6077
76-
return torch.load(s3reader, map_location) # type: ignore
76+
77+
# Explicitly set weights_only=False to:
78+
# 1. Maintain backwards compatibility with older PyTorch versions where this was the default behavior
79+
# 2. Match PyTorch Lightning's implementation strategy for consistent behavior
80+
# Reference: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/cloud_io.py#L36
81+
return torch.load(s3reader, map_location, weights_only=False) # type: ignore
7782

7883
def remove_checkpoint(
7984
self,

s3torchconnector/tst/e2e/test_e2e_s3_lightning_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_save_compatibility_with_s3_checkpoint(checkpoint_directory):
5757
s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_name}"
5858
s3_lightning_checkpoint.save_checkpoint(tensor, s3_uri)
5959
checkpoint = S3Checkpoint(region=checkpoint_directory.region)
60-
loaded_checkpoint = torch.load(checkpoint.reader(s3_uri))
60+
loaded_checkpoint = torch.load(checkpoint.reader(s3_uri), weights_only=False)
6161
assert torch.equal(tensor, loaded_checkpoint)
6262

6363

s3torchconnector/tst/e2e/test_e2e_s3checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_general_checkpointing(checkpoint_directory, tensor_dimensions):
2020
with checkpoint.writer(s3_uri) as writer:
2121
torch.save(tensor, writer)
2222

23-
loaded = torch.load(checkpoint.reader(s3_uri))
23+
loaded = torch.load(checkpoint.reader(s3_uri), weights_only=True)
2424

2525
assert torch.equal(tensor, loaded)
2626

@@ -49,7 +49,7 @@ def test_nn_checkpointing(checkpoint_directory):
4949
# assert models are not equal before loading from checkpoint
5050
assert not nn_model.equals(loaded_nn_model)
5151

52-
loaded_checkpoint = torch.load(checkpoint.reader(s3_uri))
52+
loaded_checkpoint = torch.load(checkpoint.reader(s3_uri), weights_only=True)
5353
loaded_nn_model.load_state_dict(loaded_checkpoint["model_state_dict"])
5454
assert nn_model.equals(loaded_nn_model)
5555

s3torchconnector/tst/unit/_checkpoint_byteorder_patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ def save_with_byteorder(data, fobj, byteorder: str, use_modern_pytorch_format: b
2323

2424
def load_with_byteorder(fobj, byteorder):
2525
with _patch_byteorder(byteorder):
26-
return torch.load(fobj)
26+
return torch.load(fobj, weights_only=True)

s3torchconnector/tst/unit/_hypothesis_python_primitives.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,30 @@
33

44
from hypothesis.strategies import (
55
integers,
6-
binary,
7-
none,
86
characters,
9-
complex_numbers,
107
floats,
118
booleans,
12-
decimals,
13-
fractions,
149
deferred,
15-
frozensets,
1610
tuples,
1711
dictionaries,
1812
lists,
19-
uuids,
20-
sets,
2113
text,
2214
)
2315

2416
scalars = (
25-
none()
26-
| booleans()
17+
booleans()
2718
| integers()
2819
# Disallow nan as it doesn't have self-equality
2920
| floats(allow_nan=False)
30-
| complex_numbers(allow_nan=False)
31-
| decimals(allow_nan=False)
32-
| fractions()
3321
| characters()
34-
| binary(max_size=10)
3522
| text(max_size=10)
36-
| uuids()
3723
)
3824

39-
hashable = deferred(
40-
lambda: (scalars | frozensets(hashable, max_size=5) | tuples(hashable))
41-
)
25+
hashable = deferred(lambda: (scalars | tuples(hashable)))
4226

4327
python_primitives = deferred(
4428
lambda: (
4529
hashable
46-
| sets(hashable, max_size=5)
4730
| lists(python_primitives, max_size=5)
4831
| dictionaries(keys=hashable, values=python_primitives, max_size=3)
4932
)

0 commit comments

Comments
 (0)