Skip to content

Commit 1fdf1c5

Browse files
Bordaaniketmauryapre-commit-ci[bot]
authored
adding Torch mixin (#65)
* adding Torch mixin * typing --------- Co-authored-by: Aniket Maurya <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9ca99c9 commit 1fdf1c5

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

src/litmodels/integrations/mixins.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import pickle
22
import tempfile
3+
import warnings
34
from abc import ABC
45
from pathlib import Path
5-
from typing import Optional
6+
from typing import TYPE_CHECKING, Optional
67

78
from litmodels import download_model, upload_model
89

10+
if TYPE_CHECKING:
11+
import torch
12+
913

1014
class ModelRegistryMixin(ABC):
1115
"""Mixin for model registry integration."""
@@ -83,3 +87,81 @@ def pull_from_registry(
8387
if not isinstance(obj, cls):
8488
raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}")
8589
return obj
90+
91+
92+
class PyTorchRegistryMixin(ABC):
    """Mixin for PyTorch model registry integration.

    Adds ``push_to_registry`` / ``pull_from_registry`` to an ``nn.Module`` subclass.
    Only the ``state_dict`` is stored, so pulling a model re-creates the class with
    its default constructor arguments before loading the weights.
    """

    def __post_init__(self) -> None:
        """Validate that the mixin is combined with a PyTorch ``nn.Module``.

        Raises:
            TypeError: If the instance is not a ``torch.nn.Module``.
        """
        import torch

        # NOTE(review): `__post_init__` is only invoked automatically by dataclass-generated
        # `__init__` methods; a plain `nn.Module` subclass never triggers it, which is why
        # `push_to_registry` / `pull_from_registry` perform the same check themselves.
        if not isinstance(self, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")

    def push_to_registry(
        self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None
    ) -> None:
        """Save the model's ``state_dict`` locally and upload it to the registry.

        Args:
            model_name: The name of the model. If None, the class name is used.
            model_version: The version of the model. If None, no version suffix is appended
                to the registry name.
            temp_folder: The folder in which the ``.pth`` file is written before upload.
                If None, the system temporary directory is used.

        Raises:
            TypeError: If the instance is not a ``torch.nn.Module``.
        """
        import torch

        if not isinstance(self, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")

        if model_name is None:
            model_name = self.__class__.__name__
        if temp_folder is None:
            temp_folder = tempfile.gettempdir()
        torch_path = Path(temp_folder) / f"{model_name}.pth"
        # The target directory may not exist yet: either `temp_folder` itself is missing or
        # `model_name` contains path separators, which introduces nested directories.
        torch_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), torch_path)
        # todo: dump also object creation arguments so we can dump it and load with model for object instantiation
        model_registry = f"{model_name}:{model_version}" if model_version else model_name
        upload_model(name=model_registry, model=torch_path)

    @classmethod
    def pull_from_registry(
        cls,
        model_name: str,
        model_version: Optional[str] = None,
        temp_folder: Optional[str] = None,
        torch_load_kwargs: Optional[dict] = None,
    ) -> "torch.nn.Module":
        """Download a ``state_dict`` from the registry and load it into a new instance.

        Args:
            model_name: The name of the model.
            model_version: The version of the model. If None, no version suffix is appended
                to the registry name.
            temp_folder: The folder the model files are downloaded to. If None, the system
                temporary directory is used.
            torch_load_kwargs: Additional arguments to pass to `torch.load()`.

        Returns:
            A new instance of ``cls`` (created with default constructor arguments) with the
            downloaded weights loaded.

        Raises:
            TypeError: If ``cls`` is not a ``torch.nn.Module`` subclass.
            RuntimeError: If the download contains no ``.pth`` file or more than one.
        """
        import torch

        # Fail fast: there is no point downloading weights for a class that cannot hold them.
        if not issubclass(cls, torch.nn.Module):
            raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {cls}")

        if temp_folder is None:
            temp_folder = tempfile.gettempdir()
        model_registry = f"{model_name}:{model_version}" if model_version else model_name
        files = download_model(name=model_registry, download_dir=temp_folder)
        # `str(...)` keeps the filter working whether the downloader yields strings or paths.
        torch_files = [f for f in files if str(f).endswith(".pth")]
        if not torch_files:
            raise RuntimeError(f"No torch file found for model: {model_registry} with {files}")
        if len(torch_files) > 1:
            raise RuntimeError(f"Multiple torch files found for model: {model_registry} with {torch_files}")
        state_dict_path = Path(temp_folder) / torch_files[0]
        # NOTE(security): `torch.load` may unpickle arbitrary objects depending on the torch
        # version and its `weights_only` default; callers pulling from an untrusted registry
        # should pass `torch_load_kwargs={"weights_only": True}`.
        # ignore future warning about changed default
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            state_dict = torch.load(state_dict_path, **(torch_load_kwargs or {}))

        # Instantiate with default constructor arguments (this does run `__init__`).
        instance = cls()  # todo: we need to add args used when created dumped model
        # Now load the state dict on the instance
        instance.load_state_dict(state_dict, strict=True)
        return instance

tests/integrations/test_mixins.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from unittest import mock
22

3-
from litmodels.integrations.mixins import PickleRegistryMixin
3+
import torch
4+
from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
5+
from torch import nn
46

57

68
class DummyModel(PickleRegistryMixin):
@@ -29,3 +31,38 @@ def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
2931
)
3032
# Verify that the unpickled instance has the expected value.
3133
assert loaded_dummy.value == 42
34+
35+
36+
class DummyTorchModel(nn.Module, PyTorchRegistryMixin):
    """Single linear layer with 10 outputs, used to exercise the PyTorch registry mixin."""

    def __init__(self, input_size=784):
        super().__init__()
        self.fc = nn.Linear(input_size, 10)

    def forward(self, x):
        """Flatten every sample to one dimension and apply the linear layer."""
        batch_size = x.size(0)
        flattened = x.view(batch_size, -1)
        return self.fc(flattened)
44+
45+
46+
@mock.patch("litmodels.integrations.mixins.upload_model")
@mock.patch("litmodels.integrations.mixins.download_model")
def test_pytorch_pull_updated(mock_download_model, mock_upload_model, tmp_path):
    """Round-trip a model through push/pull (registry calls mocked) and compare outputs."""
    # Create an instance and record its forward output.
    dummy = DummyTorchModel(784)
    dummy.eval()
    input_tensor = torch.randn(1, 784)
    with torch.no_grad():
        output_before = dummy(input_tensor)

    dummy.push_to_registry(temp_folder=str(tmp_path))
    expected_path = tmp_path / f"{dummy.__class__.__name__}.pth"
    mock_upload_model.assert_called_once_with(name="DummyTorchModel", model=expected_path)
    # The push itself must have written the state dict; re-saving it here would hide a
    # failure to do so.
    assert expected_path.is_file(), "push_to_registry did not write the state dict file."

    # Prepare mocking for pull_from_registry: the download reports the file written above.
    mock_download_model.return_value = [expected_path.name]
    loaded_dummy = DummyTorchModel.pull_from_registry(model_name="DummyTorchModel", temp_folder=str(tmp_path))
    mock_download_model.assert_called_once_with(name="DummyTorchModel", download_dir=str(tmp_path))
    loaded_dummy.eval()
    with torch.no_grad():
        output_after = loaded_dummy(input_tensor)

    assert isinstance(loaded_dummy, DummyTorchModel)
    # Compare the outputs as a verification.
    assert torch.allclose(output_before, output_after), "Loaded model output differs from original."

0 commit comments

Comments
 (0)