Skip to content

Commit 9ca99c9

Browse files
authored
adding Pickle mixin (#64)
* adding Pickle mixin * isinstance * compress * object
1 parent 3357462 commit 9ca99c9

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pickle
2+
import tempfile
3+
from abc import ABC
4+
from pathlib import Path
5+
from typing import Optional
6+
7+
from litmodels import download_model, upload_model
8+
9+
10+
class ModelRegistryMixin(ABC):
11+
"""Mixin for model registry integration."""
12+
13+
def push_to_registry(
14+
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
15+
) -> None:
16+
"""Push the model to the registry.
17+
18+
Args:
19+
model_name: The name of the model. If not use the class name.
20+
model_version: The version of the model. If None, the latest version is used.
21+
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
22+
"""
23+
24+
@classmethod
25+
def pull_from_registry(
26+
cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None
27+
) -> object:
28+
"""Pull the model from the registry.
29+
30+
Args:
31+
model_name: The name of the model.
32+
model_version: The version of the model. If None, the latest version is used.
33+
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
34+
"""
35+
36+
37+
class PickleRegistryMixin(ABC):
38+
"""Mixin for pickle registry integration."""
39+
40+
def push_to_registry(
41+
self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
42+
) -> None:
43+
"""Push the model to the registry.
44+
45+
Args:
46+
model_name: The name of the model. If not use the class name.
47+
model_version: The version of the model. If None, the latest version is used.
48+
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
49+
"""
50+
if model_name is None:
51+
model_name = self.__class__.__name__
52+
if temp_folder is None:
53+
temp_folder = tempfile.gettempdir()
54+
pickle_path = Path(temp_folder) / f"{model_name}.pkl"
55+
with open(pickle_path, "wb") as fp:
56+
pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
57+
model_registry = f"{model_name}:{model_version}" if model_version else model_name
58+
upload_model(name=model_registry, model=pickle_path)
59+
60+
@classmethod
61+
def pull_from_registry(
62+
cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None
63+
) -> object:
64+
"""Pull the model from the registry.
65+
66+
Args:
67+
model_name: The name of the model.
68+
model_version: The version of the model. If None, the latest version is used.
69+
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
70+
"""
71+
if temp_folder is None:
72+
temp_folder = tempfile.gettempdir()
73+
model_registry = f"{model_name}:{model_version}" if model_version else model_name
74+
files = download_model(name=model_registry, download_dir=temp_folder)
75+
pkl_files = [f for f in files if f.endswith(".pkl")]
76+
if not pkl_files:
77+
raise RuntimeError(f"No pickle file found for model: {model_registry} with {files}")
78+
if len(pkl_files) > 1:
79+
raise RuntimeError(f"Multiple pickle files found for model: {model_registry} with {pkl_files}")
80+
pkl_path = Path(temp_folder) / pkl_files[0]
81+
with open(pkl_path, "rb") as fp:
82+
obj = pickle.load(fp)
83+
if not isinstance(obj, cls):
84+
raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}")
85+
return obj

tests/integrations/test_mixins.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from unittest import mock
2+
3+
from litmodels.integrations.mixins import PickleRegistryMixin
4+
5+
6+
class DummyModel(PickleRegistryMixin):
7+
def __init__(self, value):
8+
self.value = value
9+
10+
def __eq__(self, other):
11+
return isinstance(other, DummyModel) and self.value == other.value
12+
13+
14+
@mock.patch("litmodels.integrations.mixins.upload_model")
15+
@mock.patch("litmodels.integrations.mixins.download_model")
16+
def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
17+
# Create an instance of DummyModel and call push_to_registry.
18+
dummy = DummyModel(42)
19+
dummy.push_to_registry(model_version="v1", temp_folder=str(tmp_path))
20+
# The expected registry name is "dummy_model:v1" and the file should be placed in the temp folder.
21+
expected_path = tmp_path / "DummyModel.pkl"
22+
mock_upload_model.assert_called_once_with(name="DummyModel:v1", model=expected_path)
23+
24+
# Set the mock to return the full path to the pickle file.
25+
mock_download_model.return_value = ["DummyModel.pkl"]
26+
# Call pull_from_registry and load the DummyModel instance.
27+
loaded_dummy = DummyModel.pull_from_registry(
28+
model_name="dummy_model", model_version="v1", temp_folder=str(tmp_path)
29+
)
30+
# Verify that the unpickled instance has the expected value.
31+
assert loaded_dummy.value == 42

0 commit comments

Comments
 (0)