Skip to content

Commit 36e2a2c

Browse files
fix: Torch mixin + test with Prod (#68)
* test * fixing --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2ce6344 commit 36e2a2c

File tree

3 files changed

+55
-24
lines changed

3 files changed

+55
-24
lines changed

src/litmodels/integrations/mixins.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from abc import ABC
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Optional
6+
from typing import TYPE_CHECKING, Optional, Tuple
77

88
from litmodels import download_model, upload_model
99

@@ -35,6 +35,18 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
3535
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3636
"""
3737

38+
def _setup(self, name: Optional[str] = None, temp_folder: Optional[str] = None) -> Tuple[str, str, str]:
39+
"""Parse and validate the model name and temporary folder."""
40+
if name is None:
41+
name = model_name = self.__class__.__name__
42+
elif ":" in name:
43+
raise ValueError(f"Invalid model name: '{name}'. It should not contain ':' associated with version.")
44+
else:
45+
model_name = name.split("/")[-1]
46+
if temp_folder is None:
47+
temp_folder = tempfile.mkdtemp()
48+
return name, model_name, temp_folder
49+
3850

3951
class PickleRegistryMixin(ModelRegistryMixin):
4052
"""Mixin for pickle registry integration."""
@@ -49,14 +61,7 @@ def push_to_registry(
4961
version: The version of the model. If None, the latest version is used.
5062
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
5163
"""
52-
if name is None:
53-
name = model_name = self.__class__.__name__
54-
elif ":" in name:
55-
raise ValueError(f"Invalid model name: '{name}'. It should not contain ':' associated with version.")
56-
else:
57-
model_name = name.split("/")[-1]
58-
if temp_folder is None:
59-
temp_folder = tempfile.mkdtemp()
64+
name, model_name, temp_folder = self._setup(name, temp_folder)
6065
pickle_path = Path(temp_folder) / f"{model_name}.pkl"
6166
with open(pickle_path, "wb") as fp:
6267
pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
@@ -93,14 +98,6 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
9398
class PyTorchRegistryMixin(ModelRegistryMixin):
9499
"""Mixin for PyTorch model registry integration."""
95100

96-
def __post_init__(self) -> None:
97-
"""Post-initialization method to set up the model."""
98-
import torch
99-
100-
# Ensure that the model is in evaluation mode
101-
if not isinstance(self, torch.nn.Module):
102-
raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")
103-
104101
def push_to_registry(
105102
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
106103
) -> None:
@@ -116,11 +113,8 @@ def push_to_registry(
116113
if not isinstance(self, torch.nn.Module):
117114
raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")
118115

119-
if name is None:
120-
name = self.__class__.__name__
121-
if temp_folder is None:
122-
temp_folder = tempfile.mkdtemp()
123-
torch_path = Path(temp_folder) / f"{name}.pth"
116+
name, model_name, temp_folder = self._setup(name, temp_folder)
117+
torch_path = Path(temp_folder) / f"{model_name}.pth"
124118
torch.save(self.state_dict(), torch_path)
125119
# todo: dump also object creation arguments so we can dump it and load with model for object instantiation
126120
model_registry = f"{name}:{version}" if version else name

tests/integrations/test_cloud.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from typing import Optional
66

77
import pytest
8+
import torch
89
from lightning_sdk import Teamspace
910
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1011
from lightning_sdk.utils.resolve import _resolve_teamspace
1112
from litmodels import download_model, upload_model
12-
from litmodels.integrations.mixins import PickleRegistryMixin
13+
from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
1314

1415
from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1
1516

@@ -239,3 +240,39 @@ def test_pickle_mixin_push_and_pull():
239240

240241
# CLEANING
241242
_cleanup_model(teamspace, model_name, expected_num_versions=1)
243+
244+
245+
class DummyTorchModel(torch.nn.Module, PyTorchRegistryMixin):
246+
def __init__(self, input_size=784):
247+
super().__init__()
248+
self.fc = torch.nn.Linear(input_size, 10)
249+
250+
def forward(self, x):
251+
x = x.view(x.size(0), -1)
252+
return self.fc(x)
253+
254+
255+
@pytest.mark.cloud()
256+
def test_pytorch_mixin_push_and_pull():
257+
# model name with random hash
258+
teamspace, org_team, model_name = _prepare_variables("torch_mixin")
259+
model_registry = f"{org_team}/{model_name}"
260+
261+
# Create an instance, push the model and record its forward output.
262+
dummy = DummyTorchModel(784)
263+
dummy.eval()
264+
input_tensor = torch.randn(1, 784)
265+
output_before = dummy(input_tensor)
266+
267+
dummy.push_to_registry(model_registry)
268+
269+
loaded_dummy = DummyTorchModel.pull_from_registry(model_registry)
270+
loaded_dummy.eval()
271+
output_after = loaded_dummy(input_tensor)
272+
273+
assert isinstance(loaded_dummy, DummyTorchModel)
274+
# Compare the outputs as a verification.
275+
assert torch.allclose(output_before, output_after), "Loaded model output differs from original."
276+
277+
# CLEANING
278+
_cleanup_model(teamspace, model_name, expected_num_versions=1)

tests/integrations/test_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def forward(self, x):
4343

4444
@mock.patch("litmodels.integrations.mixins.upload_model")
4545
@mock.patch("litmodels.integrations.mixins.download_model")
46-
def test_pytorch_pull_updated(mock_download_model, mock_upload_model, tmp_path):
46+
def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
4747
# Create an instance, push the model and record its forward output.
4848
dummy = DummyTorchModel(784)
4949
dummy.eval()

0 commit comments

Comments
 (0)