Skip to content

Commit 2ce6344

Browse files
Bordapre-commit-ci[bot]Copilot
authored
fix: Pickle mixin & update args + test with Prod (#67)
* fix: Pickle mixin * update args * test with Prod * Apply suggestions from code review --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <[email protected]>
1 parent c876afe commit 2ce6344

File tree

3 files changed

+69
-44
lines changed

3 files changed

+69
-44
lines changed

src/litmodels/integrations/mixins.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,66 +15,67 @@ class ModelRegistryMixin(ABC):
1515
"""Mixin for model registry integration."""
1616

1717
def push_to_registry(
18-
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
18+
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
1919
) -> None:
2020
"""Push the model to the registry.
2121
2222
Args:
23-
model_name: The name of the model. If not use the class name.
24-
model_version: The version of the model. If None, the latest version is used.
23+
name: The name of the model. If not use the class name.
24+
version: The version of the model. If None, the latest version is used.
2525
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
2626
"""
2727

2828
@classmethod
29-
def pull_from_registry(
30-
cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None
31-
) -> object:
29+
def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folder: Optional[str] = None) -> object:
3230
"""Pull the model from the registry.
3331
3432
Args:
35-
model_name: The name of the model.
36-
model_version: The version of the model. If None, the latest version is used.
33+
name: The name of the model.
34+
version: The version of the model. If None, the latest version is used.
3735
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3836
"""
3937

4038

41-
class PickleRegistryMixin(ABC):
39+
class PickleRegistryMixin(ModelRegistryMixin):
4240
"""Mixin for pickle registry integration."""
4341

4442
def push_to_registry(
45-
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
43+
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
4644
) -> None:
4745
"""Push the model to the registry.
4846
4947
Args:
50-
model_name: The name of the model. If not use the class name.
51-
model_version: The version of the model. If None, the latest version is used.
48+
name: The name of the model. If not use the class name.
49+
version: The version of the model. If None, the latest version is used.
5250
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
5351
"""
54-
if model_name is None:
55-
model_name = self.__class__.__name__
52+
if name is None:
53+
name = model_name = self.__class__.__name__
54+
elif ":" in name:
55+
raise ValueError(f"Invalid model name: '{name}'. It should not contain ':' associated with version.")
56+
else:
57+
model_name = name.split("/")[-1]
5658
if temp_folder is None:
57-
temp_folder = tempfile.gettempdir()
59+
temp_folder = tempfile.mkdtemp()
5860
pickle_path = Path(temp_folder) / f"{model_name}.pkl"
5961
with open(pickle_path, "wb") as fp:
6062
pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
61-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
62-
upload_model(name=model_registry, model=pickle_path)
63+
if version:
64+
name = f"{name}:{version}"
65+
upload_model(name=name, model=pickle_path)
6366

6467
@classmethod
65-
def pull_from_registry(
66-
cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None
67-
) -> object:
68+
def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folder: Optional[str] = None) -> object:
6869
"""Pull the model from the registry.
6970
7071
Args:
71-
model_name: The name of the model.
72-
model_version: The version of the model. If None, the latest version is used.
72+
name: The name of the model.
73+
version: The version of the model. If None, the latest version is used.
7374
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
7475
"""
7576
if temp_folder is None:
76-
temp_folder = tempfile.gettempdir()
77-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
77+
temp_folder = tempfile.mkdtemp()
78+
model_registry = f"{name}:{version}" if version else name
7879
files = download_model(name=model_registry, download_dir=temp_folder)
7980
pkl_files = [f for f in files if f.endswith(".pkl")]
8081
if not pkl_files:
@@ -89,7 +90,7 @@ def pull_from_registry(
8990
return obj
9091

9192

92-
class PyTorchRegistryMixin(ABC):
93+
class PyTorchRegistryMixin(ModelRegistryMixin):
9394
"""Mixin for PyTorch model registry integration."""
9495

9596
def __post_init__(self) -> None:
@@ -101,51 +102,51 @@ def __post_init__(self) -> None:
101102
raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")
102103

103104
def push_to_registry(
104-
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
105+
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
105106
) -> None:
106107
"""Push the model to the registry.
107108
108109
Args:
109-
model_name: The name of the model. If not use the class name.
110-
model_version: The version of the model. If None, the latest version is used.
110+
name: The name of the model. If not use the class name.
111+
version: The version of the model. If None, the latest version is used.
111112
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
112113
"""
113114
import torch
114115

115116
if not isinstance(self, torch.nn.Module):
116117
raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")
117118

118-
if model_name is None:
119-
model_name = self.__class__.__name__
119+
if name is None:
120+
name = self.__class__.__name__
120121
if temp_folder is None:
121-
temp_folder = tempfile.gettempdir()
122-
torch_path = Path(temp_folder) / f"{model_name}.pth"
122+
temp_folder = tempfile.mkdtemp()
123+
torch_path = Path(temp_folder) / f"{name}.pth"
123124
torch.save(self.state_dict(), torch_path)
124125
# todo: dump also object creation arguments so we can dump it and load with model for object instantiation
125-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
126+
model_registry = f"{name}:{version}" if version else name
126127
upload_model(name=model_registry, model=torch_path)
127128

128129
@classmethod
129130
def pull_from_registry(
130131
cls,
131-
model_name: str,
132-
model_version: Optional[str] = None,
132+
name: str,
133+
version: Optional[str] = None,
133134
temp_folder: Optional[str] = None,
134135
torch_load_kwargs: Optional[dict] = None,
135136
) -> "torch.nn.Module":
136137
"""Pull the model from the registry.
137138
138139
Args:
139-
model_name: The name of the model.
140-
model_version: The version of the model. If None, the latest version is used.
140+
name: The name of the model.
141+
version: The version of the model. If None, the latest version is used.
141142
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
142143
torch_load_kwargs: Additional arguments to pass to `torch.load()`.
143144
"""
144145
import torch
145146

146147
if temp_folder is None:
147-
temp_folder = tempfile.gettempdir()
148-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
148+
temp_folder = tempfile.mkdtemp()
149+
model_registry = f"{name}:{version}" if version else name
149150
files = download_model(name=model_registry, download_dir=temp_folder)
150151
torch_files = [f for f in files if f.endswith(".pth")]
151152
if not torch_files:

tests/integrations/test_cloud.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1010
from lightning_sdk.utils.resolve import _resolve_teamspace
1111
from litmodels import download_model, upload_model
12+
from litmodels.integrations.mixins import PickleRegistryMixin
1213

1314
from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1
1415

@@ -213,3 +214,28 @@ def test_lightning_checkpoint_ddp(importing, tmp_path):
213214

214215
# CLEANING
215216
_cleanup_model(teamspace, model_name, expected_num_versions=2)
217+
218+
219+
class DummyModel(PickleRegistryMixin):
220+
def __init__(self, value):
221+
self.value = value
222+
223+
224+
@pytest.mark.cloud()
225+
def test_pickle_mixin_push_and_pull():
226+
# model name with random hash
227+
teamspace, org_team, model_name = _prepare_variables("pickle_mixin")
228+
model_registry = f"{org_team}/{model_name}"
229+
230+
# Create an instance of DummyModel and call push_to_registry.
231+
dummy = DummyModel(42)
232+
dummy.push_to_registry(model_registry)
233+
234+
# Call pull_from_registry and load the DummyModel instance.
235+
loaded_dummy = DummyModel.pull_from_registry(model_registry)
236+
# Verify that the unpickled instance has the expected value.
237+
assert isinstance(loaded_dummy, DummyModel)
238+
assert loaded_dummy.value == 42
239+
240+
# CLEANING
241+
_cleanup_model(teamspace, model_name, expected_num_versions=1)

tests/integrations/test_mixins.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,15 @@ def __eq__(self, other):
1818
def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
1919
# Create an instance of DummyModel and call push_to_registry.
2020
dummy = DummyModel(42)
21-
dummy.push_to_registry(model_version="v1", temp_folder=str(tmp_path))
21+
dummy.push_to_registry(version="v1", temp_folder=str(tmp_path))
2222
# The expected registry name is "dummy_model:v1" and the file should be placed in the temp folder.
2323
expected_path = tmp_path / "DummyModel.pkl"
2424
mock_upload_model.assert_called_once_with(name="DummyModel:v1", model=expected_path)
2525

2626
# Set the mock to return the full path to the pickle file.
2727
mock_download_model.return_value = ["DummyModel.pkl"]
2828
# Call pull_from_registry and load the DummyModel instance.
29-
loaded_dummy = DummyModel.pull_from_registry(
30-
model_name="dummy_model", model_version="v1", temp_folder=str(tmp_path)
31-
)
29+
loaded_dummy = DummyModel.pull_from_registry(name="dummy_model", version="v1", temp_folder=str(tmp_path))
3230
# Verify that the unpickled instance has the expected value.
3331
assert loaded_dummy.value == 42
3432

@@ -59,7 +57,7 @@ def test_pytorch_pull_updated(mock_download_model, mock_upload_model, tmp_path):
5957
torch.save(dummy.state_dict(), expected_path)
6058
# Prepare mocking for pull_from_registry.
6159
mock_download_model.return_value = [f"{dummy.__class__.__name__}.pth"]
62-
loaded_dummy = DummyTorchModel.pull_from_registry(model_name="DummyTorchModel", temp_folder=str(tmp_path))
60+
loaded_dummy = DummyTorchModel.pull_from_registry(name="DummyTorchModel", temp_folder=str(tmp_path))
6361
loaded_dummy.eval()
6462
output_after = loaded_dummy(input_tensor)
6563

0 commit comments

Comments
 (0)