
Commit bd2ec00

Authored by Borda (co-authored by pre-commit-ci[bot] and Copilot)
enable Torch mixin with init args (#69)
* fixing PT mixin
* update + test
* Apply suggestions from code review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <[email protected]>
1 parent 36e2a2c commit bd2ec00

File tree: 5 files changed (+128, -48 lines)
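In short, the commit teaches `PyTorchRegistryMixin` to record the constructor arguments of the child model and to replay them when the model is pulled back from the registry. A minimal usage sketch, assuming the `litmodels` package from this repository and an already configured registry (the `TinyNet` class and the version tag are illustrative, not part of the commit):

    import torch
    from litmodels.integrations.mixins import PyTorchRegistryMixin


    # The mixin is listed first, matching the inheritance order used in the tests below.
    class TinyNet(PyTorchRegistryMixin, torch.nn.Module):
        def __init__(self, input_size: int, output_size: int = 10):
            super().__init__()
            self.fc = torch.nn.Linear(input_size, output_size)

        def forward(self, x):
            return self.fc(x)


    model = TinyNet(784)                  # init args are captured by the mixin's __new__
    model.push_to_registry(version="v1")  # uploads the state dict plus TinyNet__init_kwargs.json
    restored = TinyNet.pull_from_registry(name="TinyNet", version="v1")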

src/litmodels/integrations/mixins.py

Lines changed: 78 additions & 17 deletions
@@ -1,11 +1,15 @@
+import inspect
+import json
 import pickle
 import tempfile
 import warnings
 from abc import ABC
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

-from litmodels import download_model, upload_model
+from lightning_utilities.core.rank_zero import rank_zero_warn
+
+from litmodels.io.cloud import download_model_files, upload_model_files

 if TYPE_CHECKING:
     import torch
@@ -15,7 +19,7 @@ class ModelRegistryMixin(ABC):
     """Mixin for model registry integration."""

     def push_to_registry(
-        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
+        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
     ) -> None:
         """Push the model to the registry.

@@ -26,7 +30,9 @@ def push_to_registry(
         """

     @classmethod
-    def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folder: Optional[str] = None) -> object:
+    def pull_from_registry(
+        cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
+    ) -> object:
         """Pull the model from the registry.

         Args:
@@ -35,7 +41,9 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
             temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
         """

-    def _setup(self, name: Optional[str] = None, temp_folder: Optional[str] = None) -> Tuple[str, str, str]:
+    def _setup(
+        self, name: Optional[str] = None, temp_folder: Union[str, Path, None] = None
+    ) -> Tuple[str, str, Union[str, Path]]:
         """Parse and validate the model name and temporary folder."""
         if name is None:
             name = model_name = self.__class__.__name__
@@ -52,7 +60,7 @@ class PickleRegistryMixin(ModelRegistryMixin):
     """Mixin for pickle registry integration."""

     def push_to_registry(
-        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
+        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
     ) -> None:
         """Push the model to the registry.

@@ -67,10 +75,12 @@ def push_to_registry(
             pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
         if version:
             name = f"{name}:{version}"
-        upload_model(name=name, model=pickle_path)
+        upload_model_files(name=name, path=pickle_path)

     @classmethod
-    def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folder: Optional[str] = None) -> object:
+    def pull_from_registry(
+        cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
+    ) -> object:
         """Pull the model from the registry.

         Args:
@@ -81,7 +91,7 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
         if temp_folder is None:
             temp_folder = tempfile.mkdtemp()
         model_registry = f"{name}:{version}" if version else name
-        files = download_model(name=model_registry, download_dir=temp_folder)
+        files = download_model_files(name=model_registry, download_dir=temp_folder)
         pkl_files = [f for f in files if f.endswith(".pkl")]
         if not pkl_files:
             raise RuntimeError(f"No pickle file found for model: {model_registry} with {files}")
@@ -98,8 +108,27 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
 class PyTorchRegistryMixin(ModelRegistryMixin):
     """Mixin for PyTorch model registry integration."""

+    def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
+        """Create a new instance of the class without calling __init__."""
+        instance = super().__new__(cls)
+
+        # Get __init__ signature excluding 'self'
+        init_sig = inspect.signature(cls.__init__)
+        params = list(init_sig.parameters.values())[1:]  # Skip self
+
+        # Create temporary signature for binding
+        temp_sig = init_sig.replace(parameters=params)
+
+        # Bind and apply defaults
+        bound_args = temp_sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+
+        # Store unified kwargs
+        instance.__init_kwargs = bound_args.arguments
+        return instance
+
     def push_to_registry(
-        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Optional[str] = None
+        self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
     ) -> None:
         """Push the model to the registry.

@@ -110,22 +139,43 @@ def push_to_registry(
         """
         import torch

+        # Ensure that the model is in evaluation mode
         if not isinstance(self, torch.nn.Module):
             raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}")

         name, model_name, temp_folder = self._setup(name, temp_folder)
-        torch_path = Path(temp_folder) / f"{model_name}.pth"
-        torch.save(self.state_dict(), torch_path)
-        # todo: dump also object creation arguments so we can dump it and load with model for object instantiation
+
+        if self.__init_kwargs:
+            try:
+                # Save the model arguments to a JSON file
+                init_kwargs_path = Path(temp_folder) / f"{model_name}__init_kwargs.json"
+                with open(init_kwargs_path, "w") as fp:
+                    json.dump(self.__init_kwargs, fp)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to save model arguments: {e}."
+                    " Ensure the model's arguments are JSON serializable or use `PickleRegistryMixin`."
+                ) from e
+        elif not hasattr(self, "__init_kwargs"):
+            rank_zero_warn(
+                "The child class is missing `__init_kwargs`."
+                " Ensure `PyTorchRegistryMixin` is first in the inheritance order"
+                " or call `PyTorchRegistryMixin.__init__` explicitly in the child class."
+            )
+
+        torch_state_dict_path = Path(temp_folder) / f"{model_name}.pth"
+        torch.save(self.state_dict(), torch_state_dict_path)
         model_registry = f"{name}:{version}" if version else name
-        upload_model(name=model_registry, model=torch_path)
+        # todo: consider creating another temp folder and copying these two files
+        # todo: updating SDK to support uploading just specific files
+        upload_model_files(name=model_registry, path=temp_folder)

     @classmethod
     def pull_from_registry(
         cls,
         name: str,
         version: Optional[str] = None,
-        temp_folder: Optional[str] = None,
+        temp_folder: Union[str, Path, None] = None,
         torch_load_kwargs: Optional[dict] = None,
     ) -> "torch.nn.Module":
         """Pull the model from the registry.
@@ -141,7 +191,8 @@ def pull_from_registry(
         if temp_folder is None:
             temp_folder = tempfile.mkdtemp()
         model_registry = f"{name}:{version}" if version else name
-        files = download_model(name=model_registry, download_dir=temp_folder)
+        files = download_model_files(name=model_registry, download_dir=temp_folder)
+
         torch_files = [f for f in files if f.endswith(".pth")]
         if not torch_files:
             raise RuntimeError(f"No torch file found for model: {model_registry} with {files}")
@@ -153,8 +204,18 @@ def pull_from_registry(
             warnings.simplefilter("ignore", category=FutureWarning)
             state_dict = torch.load(state_dict_path, **(torch_load_kwargs if torch_load_kwargs else {}))

+        init_files = [fp for fp in files if fp.endswith("__init_kwargs.json")]
+        if not init_files:
+            init_kwargs = {}
+        elif len(init_files) > 1:
+            raise RuntimeError(f"Multiple init files found for model: {model_registry} with {init_files}")
+        else:
+            init_kwargs_path = Path(temp_folder) / init_files[0]
+            with open(init_kwargs_path) as fp:
+                init_kwargs = json.load(fp)
+
         # Create a new model instance without calling __init__
-        instance = cls()  # todo: we need to add args used when created dumped model
+        instance = cls(**init_kwargs)
         if not isinstance(instance, torch.nn.Module):
             raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(instance)}")
         # Now load the state dict on the instance
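The argument capture in `__new__` above is plain `inspect` machinery: the call is bound against the child's `__init__` signature (minus `self`) and defaults are filled in. A standalone sketch of the same step, using a hypothetical `Demo` class:

    import inspect


    class Demo:
        def __init__(self, input_size: int, output_size: int = 10):
            ...


    sig = inspect.signature(Demo.__init__)
    params = list(sig.parameters.values())[1:]        # drop `self`
    bound = sig.replace(parameters=params).bind(784)  # same call shape as Demo(784)
    bound.apply_defaults()
    print(dict(bound.arguments))                      # {'input_size': 784, 'output_size': 10}

Because the result is a plain name-to-value mapping, it is what gets serialized to the `<model>__init_kwargs.json` file in `push_to_registry` and fed back as `cls(**init_kwargs)` in `pull_from_registry`.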

src/litmodels/io/cloud.py

Lines changed: 3 additions & 2 deletions
@@ -2,6 +2,7 @@
 # Licensed under the Apache License, Version 2.0 (the "License");
 # http://www.apache.org/licenses/LICENSE-2.0
 #
+from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional, Union

 from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
@@ -41,7 +42,7 @@ def _print_model_link(name: str, verbose: Union[bool, int]) -> None:

 def upload_model_files(
     name: str,
-    path: str,
+    path: Union[str, Path],
     progress_bar: bool = True,
     cloud_account: Optional[str] = None,
     verbose: Union[bool, int] = 1,
@@ -71,7 +72,7 @@ def upload_model_files(

 def download_model_files(
     name: str,
-    download_dir: str = ".",
+    download_dir: Union[str, Path] = ".",
     progress_bar: bool = True,
 ) -> Union[str, List[str]]:
     """Download a checkpoint from the model store.

src/litmodels/io/gateway.py

Lines changed: 7 additions & 10 deletions
@@ -10,7 +10,6 @@

 if module_available("torch"):
     import torch
-    from torch.nn import Module
 else:
     torch = None

@@ -20,7 +19,7 @@

 def upload_model(
     name: str,
-    model: Union[str, Path, "Module", Any],
+    model: Union[str, Path, "torch.nn.Module", Any],
     progress_bar: bool = True,
     cloud_account: Optional[str] = None,
     staging_dir: Optional[str] = None,
@@ -42,19 +41,17 @@ def upload_model(
     """
     if not staging_dir:
         staging_dir = tempfile.mkdtemp()
+    if isinstance(model, (str, Path)):
+        path = model
     # if LightningModule and isinstance(model, LightningModule):
     #     path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
     #     model.save_checkpoint(path)
-    if torch and isinstance(model, torch.jit.ScriptModule):
+    elif torch and isinstance(model, torch.jit.ScriptModule):
         path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
         model.save(path)
-    elif torch and isinstance(model, Module):
+    elif torch and isinstance(model, torch.nn.Module):
         path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
         torch.save(model.state_dict(), path)
-    elif isinstance(model, str):
-        path = model
-    elif isinstance(model, Path):
-        path = str(model)
     else:
         path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
         joblib.dump(model, path)
@@ -70,7 +67,7 @@ def upload_model(

 def download_model(
     name: str,
-    download_dir: str = ".",
+    download_dir: Union[str, Path] = ".",
     progress_bar: bool = True,
 ) -> Union[str, List[str]]:
     """Download a checkpoint from the model store.
@@ -109,7 +106,7 @@ def load_model(name: str, download_dir: str = ".") -> Any:
     download_paths = [p for p in download_paths if Path(p).suffix.lower() not in {".md", ".txt", ".rst"}]
     if len(download_paths) > 1:
         raise NotImplementedError("Downloaded model with multiple files is not supported yet.")
-    model_path = Path(os.path.join(download_dir, download_paths[0]))
+    model_path = Path(download_dir) / download_paths[0]
     if model_path.suffix.lower() == ".pkl":
         return joblib.load(model_path)
     if model_path.suffix.lower() == ".ts":
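`upload_model` still dispatches on the type of `model`, but the `str`/`Path` branch now comes first so file paths are handed over untouched instead of falling through the torch checks. A sketch of three of the branches (names and files are illustrative):

    from pathlib import Path

    import torch
    from litmodels import upload_model

    upload_model(name="DemoModel", model=Path("weights/model.pth"))  # str/Path: used directly
    upload_model(name="DemoModel", model=torch.nn.Linear(4, 2))      # nn.Module: state_dict saved as .pth
    upload_model(name="DemoObject", model={"answer": 42})            # anything else: pickled via joblib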

tests/integrations/test_cloud.py

Lines changed: 7 additions & 3 deletions
@@ -242,10 +242,14 @@ def test_pickle_mixin_push_and_pull():
     _cleanup_model(teamspace, model_name, expected_num_versions=1)


-class DummyTorchModel(torch.nn.Module, PyTorchRegistryMixin):
-    def __init__(self, input_size=784):
+# This is a dummy model for PyTorch that uses the PyTorchRegistryMixin.
+# This mixin has to be first in the inheritance order.
+# Otherwise, `PyTorchRegistryMixin.__init__` need to be called explicitly.
+class DummyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
+    def __init__(self, input_size: int, output_size: int = 10):
+        # PyTorchRegistryMixin.__init__ will capture these arguments
         super().__init__()
-        self.fc = torch.nn.Linear(input_size, 10)
+        self.fc = torch.nn.Linear(input_size, output_size)

     def forward(self, x):
         x = x.view(x.size(0), -1)
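The comment added above states the convention; for any concrete model you can inspect where the mixin actually lands in the method resolution order. A quick sketch for the dummy model, with the expected output shown as a comment (assuming the class hierarchy from the mixins diff above):

    import torch
    from litmodels.integrations.mixins import PyTorchRegistryMixin


    class DummyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
        def __init__(self, input_size: int, output_size: int = 10):
            super().__init__()
            self.fc = torch.nn.Linear(input_size, output_size)


    print([c.__name__ for c in DummyTorchModel.__mro__])
    # ['DummyTorchModel', 'PyTorchRegistryMixin', 'ModelRegistryMixin', 'ABC', 'Module', 'object']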

tests/integrations/test_mixins.py

Lines changed: 33 additions & 16 deletions
@@ -1,5 +1,6 @@
 from unittest import mock

+import pytest
 import torch
 from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
 from torch import nn
@@ -13,15 +14,15 @@ def __eq__(self, other):
         return isinstance(other, DummyModel) and self.value == other.value


-@mock.patch("litmodels.integrations.mixins.upload_model")
-@mock.patch("litmodels.integrations.mixins.download_model")
+@mock.patch("litmodels.integrations.mixins.upload_model_files")
+@mock.patch("litmodels.integrations.mixins.download_model_files")
 def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
     # Create an instance of DummyModel and call push_to_registry.
     dummy = DummyModel(42)
     dummy.push_to_registry(version="v1", temp_folder=str(tmp_path))
     # The expected registry name is "dummy_model:v1" and the file should be placed in the temp folder.
     expected_path = tmp_path / "DummyModel.pkl"
-    mock_upload_model.assert_called_once_with(name="DummyModel:v1", model=expected_path)
+    mock_upload_model.assert_called_once_with(name="DummyModel:v1", path=expected_path)

     # Set the mock to return the full path to the pickle file.
     mock_download_model.return_value = ["DummyModel.pkl"]
@@ -31,36 +32,52 @@ def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
     assert loaded_dummy.value == 42


-class DummyTorchModel(nn.Module, PyTorchRegistryMixin):
-    def __init__(self, input_size=784):
+class DummyTorchModelFirst(PyTorchRegistryMixin, nn.Module):
+    def __init__(self, input_size: int, output_size: int = 10):
+        # PyTorchRegistryMixin.__init__ will capture these arguments
         super().__init__()
-        self.fc = nn.Linear(input_size, 10)
+        self.fc = nn.Linear(input_size, output_size)

     def forward(self, x):
         x = x.view(x.size(0), -1)
         return self.fc(x)


-@mock.patch("litmodels.integrations.mixins.upload_model")
-@mock.patch("litmodels.integrations.mixins.download_model")
-def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
+class DummyTorchModelSecond(nn.Module, PyTorchRegistryMixin):
+    def __init__(self, input_size: int, output_size: int = 10):
+        PyTorchRegistryMixin.__init__(input_size, output_size)
+        super().__init__()
+        self.fc = nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        return self.fc(x)
+
+
+@pytest.mark.parametrize("torch_class", [DummyTorchModelFirst, DummyTorchModelSecond])
+@mock.patch("litmodels.integrations.mixins.upload_model_files")
+@mock.patch("litmodels.integrations.mixins.download_model_files")
+def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, torch_class, tmp_path):
     # Create an instance, push the model and record its forward output.
-    dummy = DummyTorchModel(784)
+    dummy = torch_class(784)
     dummy.eval()
     input_tensor = torch.randn(1, 784)
     output_before = dummy(input_tensor)

     dummy.push_to_registry(temp_folder=str(tmp_path))
-    expected_path = tmp_path / f"{dummy.__class__.__name__}.pth"
-    mock_upload_model.assert_called_once_with(name="DummyTorchModel", model=expected_path)
+    mock_upload_model.assert_called_once_with(name=torch_class.__name__, path=str(tmp_path))

-    torch.save(dummy.state_dict(), expected_path)
+    torch_file = f"{dummy.__class__.__name__}.pth"
+    torch.save(dummy.state_dict(), tmp_path / torch_file)
+    json_file = f"{dummy.__class__.__name__}__init_kwargs.json"
+    with open(tmp_path / json_file, "w") as fp:
+        fp.write('{"input_size": 784, "output_size": 10}')
     # Prepare mocking for pull_from_registry.
-    mock_download_model.return_value = [f"{dummy.__class__.__name__}.pth"]
-    loaded_dummy = DummyTorchModel.pull_from_registry(name="DummyTorchModel", temp_folder=str(tmp_path))
+    mock_download_model.return_value = [torch_file, json_file]
+    loaded_dummy = torch_class.pull_from_registry(name=torch_class.__name__, temp_folder=str(tmp_path))
     loaded_dummy.eval()
     output_after = loaded_dummy(input_tensor)

-    assert isinstance(loaded_dummy, DummyTorchModel)
+    assert isinstance(loaded_dummy, torch_class)
     # Compare the outputs as a verification.
     assert torch.allclose(output_before, output_after), "Loaded model output differs from original."
