Skip to content

Commit 6b7142f

Browse files
authored
update mixin's methods (#76)
1 parent 643fe28 commit 6b7142f

File tree

4 files changed

+18
-18
lines changed

4 files changed

+18
-18
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,10 @@ class MyModel(PickleRegistryMixin):
197197

198198
# Create and push a model instance
199199
model = MyModel(param1=42, param2="hello")
200-
model.push_to_registry(name="my-org/my-team/my-model")
200+
model.upload_model(name="my-org/my-team/my-model")
201201

202202
# Later, pull the model
203-
loaded_model = MyModel.pull_from_registry(name="my-org/my-team/my-model")
203+
loaded_model = MyModel.download_model(name="my-org/my-team/my-model")
204204
```
205205

206206
### Using `PyTorchRegistryMixin`
@@ -225,8 +225,8 @@ class MyTorchModel(PyTorchRegistryMixin, torch.nn.Module):
225225

226226
# Create and push the model
227227
model = MyTorchModel(input_size=784)
228-
model.push_to_registry(name="my-org/my-team/torch-model")
228+
model.upload_model(name="my-org/my-team/torch-model")
229229

230230
# Pull the model with the same architecture
231-
loaded_model = MyTorchModel.pull_from_registry(name="my-org/my-team/torch-model")
231+
loaded_model = MyTorchModel.download_model(name="my-org/my-team/torch-model")
232232
```

src/litmodels/integrations/mixins.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
class ModelRegistryMixin(ABC):
1919
"""Mixin for model registry integration."""
2020

21-
def push_to_registry(
21+
def upload_model(
2222
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
2323
) -> None:
2424
"""Push the model to the registry.
@@ -30,7 +30,7 @@ def push_to_registry(
3030
"""
3131

3232
@classmethod
33-
def pull_from_registry(
33+
def download_model(
3434
cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
3535
) -> object:
3636
"""Pull the model from the registry.
@@ -59,7 +59,7 @@ def _setup(
5959
class PickleRegistryMixin(ModelRegistryMixin):
6060
"""Mixin for pickle registry integration."""
6161

62-
def push_to_registry(
62+
def upload_model(
6363
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
6464
) -> None:
6565
"""Push the model to the registry.
@@ -78,7 +78,7 @@ def push_to_registry(
7878
upload_model_files(name=name, path=pickle_path)
7979

8080
@classmethod
81-
def pull_from_registry(
81+
def download_model(
8282
cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
8383
) -> object:
8484
"""Pull the model from the registry.
@@ -127,7 +127,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
127127
instance.__init_kwargs = bound_args.arguments
128128
return instance
129129

130-
def push_to_registry(
130+
def upload_model(
131131
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
132132
) -> None:
133133
"""Push the model to the registry.
@@ -171,7 +171,7 @@ def push_to_registry(
171171
upload_model_files(name=model_registry, path=[torch_state_dict_path, init_kwargs_path])
172172

173173
@classmethod
174-
def pull_from_registry(
174+
def download_model(
175175
cls,
176176
name: str,
177177
version: Optional[str] = None,

tests/integrations/test_cloud.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,10 @@ def test_pickle_mixin_push_and_pull():
235235

236236
# Create an instance of DummyModel and call push_to_registry.
237237
dummy = DummyModel(42)
238-
dummy.push_to_registry(model_registry)
238+
dummy.upload_model(model_registry)
239239

240240
# Call pull_from_registry and load the DummyModel instance.
241-
loaded_dummy = DummyModel.pull_from_registry(model_registry)
241+
loaded_dummy = DummyModel.download_model(model_registry)
242242
# Verify that the unpickled instance has the expected value.
243243
assert isinstance(loaded_dummy, DummyModel)
244244
assert loaded_dummy.value == 42
@@ -273,9 +273,9 @@ def test_pytorch_mixin_push_and_pull():
273273
input_tensor = torch.randn(1, 784)
274274
output_before = dummy(input_tensor)
275275

276-
dummy.push_to_registry(model_registry)
276+
dummy.upload_model(model_registry)
277277

278-
loaded_dummy = DummyTorchModel.pull_from_registry(model_registry)
278+
loaded_dummy = DummyTorchModel.download_model(model_registry)
279279
loaded_dummy.eval()
280280
output_after = loaded_dummy(input_tensor)
281281

tests/integrations/test_mixins.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def __eq__(self, other):
1919
def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
2020
# Create an instance of DummyModel and call push_to_registry.
2121
dummy = DummyModel(42)
22-
dummy.push_to_registry(version="v1", temp_folder=str(tmp_path))
22+
dummy.upload_model(version="v1", temp_folder=str(tmp_path))
2323
# The expected registry name is "dummy_model:v1" and the file should be placed in the temp folder.
2424
expected_path = tmp_path / "DummyModel.pkl"
2525
mock_upload_model.assert_called_once_with(name="DummyModel:v1", path=expected_path)
2626

2727
# Set the mock to return the full path to the pickle file.
2828
mock_download_model.return_value = ["DummyModel.pkl"]
2929
# Call pull_from_registry and load the DummyModel instance.
30-
loaded_dummy = DummyModel.pull_from_registry(name="dummy_model", version="v1", temp_folder=str(tmp_path))
30+
loaded_dummy = DummyModel.download_model(name="dummy_model", version="v1", temp_folder=str(tmp_path))
3131
# Verify that the unpickled instance has the expected value.
3232
assert loaded_dummy.value == 42
3333

@@ -71,14 +71,14 @@ def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, torch_cla
7171
with open(json_path, "w") as fp:
7272
fp.write('{"input_size": 784, "output_size": 10}')
7373

74-
dummy.push_to_registry(temp_folder=str(tmp_path))
74+
dummy.upload_model(temp_folder=str(tmp_path))
7575
mock_upload_model.assert_called_once_with(
7676
name=torch_class.__name__, path=[tmp_path / f"{torch_class.__name__}.pth", json_path]
7777
)
7878

7979
# Prepare mocking for pull_from_registry.
8080
mock_download_model.return_value = [torch_file, json_file]
81-
loaded_dummy = torch_class.pull_from_registry(name=torch_class.__name__, temp_folder=str(tmp_path))
81+
loaded_dummy = torch_class.download_model(name=torch_class.__name__, temp_folder=str(tmp_path))
8282
loaded_dummy.eval()
8383
output_after = loaded_dummy(input_tensor)
8484

0 commit comments

Comments
 (0)