Skip to content

Commit 44113b6

Browse files
committed
fix: Pickle mixin & update args + test with Prod
1 parent c876afe commit 44113b6

File tree

3 files changed

+68
-37
lines changed

3 files changed

+68
-37
lines changed

src/litmodels/integrations/mixins.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,66 +15,71 @@ class ModelRegistryMixin(ABC):
1515
"""Mixin for model registry integration."""
1616

1717
def push_to_registry(
18-
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
18+
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
1919
) -> None:
2020
"""Push the model to the registry.
2121
2222
Args:
23-
model_name: The name of the model. If not use the class name.
24-
model_version: The version of the model. If None, the latest version is used.
23+
name: The name of the model. If not use the class name.
24+
version: The version of the model. If None, the latest version is used.
2525
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
2626
"""
2727

2828
@classmethod
2929
def pull_from_registry(
30-
cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None
30+
cls, name: str, version: Optional[str] = None, temp_folder: Optional[str] = None
3131
) -> object:
3232
"""Pull the model from the registry.
3333
3434
Args:
35-
model_name: The name of the model.
36-
model_version: The version of the model. If None, the latest version is used.
35+
name: The name of the model.
36+
version: The version of the model. If None, the latest version is used.
3737
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3838
"""
3939

4040

41-
class PickleRegistryMixin(ABC):
41+
class PickleRegistryMixin(ModelRegistryMixin):
4242
"""Mixin for pickle registry integration."""
4343

4444
def push_to_registry(
45-
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
45+
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
4646
) -> None:
4747
"""Push the model to the registry.
4848
4949
Args:
50-
model_name: The name of the model. If not use the class name.
51-
model_version: The version of the model. If None, the latest version is used.
50+
name: The name of the model. If not use the class name.
51+
version: The version of the model. If None, the latest version is used.
5252
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
5353
"""
54-
if model_name is None:
55-
model_name = self.__class__.__name__
54+
if ":" in name:
55+
raise ValueError(f"Invalid model name: '{name}'. It should not contain ':' associated with version.")
56+
if name is None:
57+
name = model_name = self.__class__.__name__
58+
else:
59+
model_name = name.split("/")[-1]
5660
if temp_folder is None:
57-
temp_folder = tempfile.gettempdir()
61+
temp_folder = tempfile.mkdtemp()
5862
pickle_path = Path(temp_folder) / f"{model_name}.pkl"
5963
with open(pickle_path, "wb") as fp:
6064
pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
61-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
62-
upload_model(name=model_registry, model=pickle_path)
65+
if version:
66+
name = f"{name}:{version}"
67+
upload_model(name=name, model=pickle_path)
6368

6469
@classmethod
6570
def pull_from_registry(
66-
cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None
71+
cls, name: str, version: Optional[str] = None, temp_folder: Optional[str] = None
6772
) -> object:
6873
"""Pull the model from the registry.
6974
7075
Args:
71-
model_name: The name of the model.
72-
model_version: The version of the model. If None, the latest version is used.
76+
name: The name of the model.
77+
version: The version of the model. If None, the latest version is used.
7378
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
7479
"""
7580
if temp_folder is None:
76-
temp_folder = tempfile.gettempdir()
77-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
81+
temp_folder = tempfile.mkdtemp()
82+
model_registry = f"{name}:{version}" if version else name
7883
files = download_model(name=model_registry, download_dir=temp_folder)
7984
pkl_files = [f for f in files if f.endswith(".pkl")]
8085
if not pkl_files:
@@ -89,7 +94,7 @@ def pull_from_registry(
8994
return obj
9095

9196

92-
class PyTorchRegistryMixin(ABC):
97+
class PyTorchRegistryMixin(ModelRegistryMixin):
9398
"""Mixin for PyTorch model registry integration."""
9499

95100
def __post_init__(self) -> None:
@@ -101,51 +106,51 @@ def __post_init__(self) -> None:
101106
raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")
102107

103108
def push_to_registry(
104-
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
109+
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
105110
) -> None:
106111
"""Push the model to the registry.
107112
108113
Args:
109-
model_name: The name of the model. If not use the class name.
110-
model_version: The version of the model. If None, the latest version is used.
114+
name: The name of the model. If not use the class name.
115+
version: The version of the model. If None, the latest version is used.
111116
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
112117
"""
113118
import torch
114119

115120
if not isinstance(self, torch.nn.Module):
116121
raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")
117122

118-
if model_name is None:
119-
model_name = self.__class__.__name__
123+
if name is None:
124+
name = self.__class__.__name__
120125
if temp_folder is None:
121-
temp_folder = tempfile.gettempdir()
122-
torch_path = Path(temp_folder) / f"{model_name}.pth"
126+
temp_folder = tempfile.mkdtemp()
127+
torch_path = Path(temp_folder) / f"{name}.pth"
123128
torch.save(self.state_dict(), torch_path)
124129
# todo: dump also object creation arguments so we can dump it and load with model for object instantiation
125-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
130+
model_registry = f"{name}:{version}" if version else name
126131
upload_model(name=model_registry, model=torch_path)
127132

128133
@classmethod
129134
def pull_from_registry(
130135
cls,
131-
model_name: str,
132-
model_version: Optional[str] = None,
136+
name: str,
137+
version: Optional[str] = None,
133138
temp_folder: Optional[str] = None,
134139
torch_load_kwargs: Optional[dict] = None,
135140
) -> "torch.nn.Module":
136141
"""Pull the model from the registry.
137142
138143
Args:
139-
model_name: The name of the model.
140-
model_version: The version of the model. If None, the latest version is used.
144+
name: The name of the model.
145+
version: The version of the model. If None, the latest version is used.
141146
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
142147
torch_load_kwargs: Additional arguments to pass to `torch.load()`.
143148
"""
144149
import torch
145150

146151
if temp_folder is None:
147-
temp_folder = tempfile.gettempdir()
148-
model_registry = f"{model_name}:{model_version}" if model_version else model_name
152+
temp_folder = tempfile.mkdtemp()
153+
model_registry = f"{name}:{version}" if version else name
149154
files = download_model(name=model_registry, download_dir=temp_folder)
150155
torch_files = [f for f in files if f.endswith(".pth")]
151156
if not torch_files:

tests/integrations/test_cloud.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1010
from lightning_sdk.utils.resolve import _resolve_teamspace
1111
from litmodels import download_model, upload_model
12+
from litmodels.integrations.mixins import PickleRegistryMixin
1213

1314
from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1
1415

@@ -213,3 +214,28 @@ def test_lightning_checkpoint_ddp(importing, tmp_path):
213214

214215
# CLEANING
215216
_cleanup_model(teamspace, model_name, expected_num_versions=2)
217+
218+
219+
class DummyModel(PickleRegistryMixin):
220+
def __init__(self, value):
221+
self.value = value
222+
223+
224+
@pytest.mark.cloud()
225+
def test_picklemixin_push_and_pull():
226+
# model name with random hash
227+
teamspace, org_team, model_name = _prepare_variables("picklemixin")
228+
model_registry = f"{org_team}/{model_name}"
229+
230+
# Create an instance of DummyModel and call push_to_registry.
231+
dummy = DummyModel(42)
232+
dummy.push_to_registry(model_registry)
233+
234+
# Call pull_from_registry and load the DummyModel instance.
235+
loaded_dummy = DummyModel.pull_from_registry(model_registry)
236+
# Verify that the unpickled instance has the expected value.
237+
assert isinstance(loaded_dummy, DummyModel)
238+
assert loaded_dummy.value == 42
239+
240+
# CLEANING
241+
_cleanup_model(teamspace, model_name, expected_num_versions=1)

tests/integrations/test_mixins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
2727
mock_download_model.return_value = ["DummyModel.pkl"]
2828
# Call pull_from_registry and load the DummyModel instance.
2929
loaded_dummy = DummyModel.pull_from_registry(
30-
model_name="dummy_model", model_version="v1", temp_folder=str(tmp_path)
30+
name="dummy_model", version="v1", temp_folder=str(tmp_path)
3131
)
3232
# Verify that the unpickled instance has the expected value.
3333
assert loaded_dummy.value == 42
@@ -59,7 +59,7 @@ def test_pytorch_pull_updated(mock_download_model, mock_upload_model, tmp_path):
5959
torch.save(dummy.state_dict(), expected_path)
6060
# Prepare mocking for pull_from_registry.
6161
mock_download_model.return_value = [f"{dummy.__class__.__name__}.pth"]
62-
loaded_dummy = DummyTorchModel.pull_from_registry(model_name="DummyTorchModel", temp_folder=str(tmp_path))
62+
loaded_dummy = DummyTorchModel.pull_from_registry(name="DummyTorchModel", temp_folder=str(tmp_path))
6363
loaded_dummy.eval()
6464
output_after = loaded_dummy(input_tensor)
6565

0 commit comments

Comments
 (0)