Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ class MyModel(PickleRegistryMixin):

# Create and push a model instance
model = MyModel(param1=42, param2="hello")
model.push_to_registry(name="my-org/my-team/my-model")
model.upload_model(name="my-org/my-team/my-model")

# Later, pull the model
loaded_model = MyModel.pull_from_registry(name="my-org/my-team/my-model")
loaded_model = MyModel.download_model(name="my-org/my-team/my-model")
```

### Using `PyTorchRegistryMixin`
Expand All @@ -225,8 +225,8 @@ class MyTorchModel(PyTorchRegistryMixin, torch.nn.Module):

# Create and push the model
model = MyTorchModel(input_size=784)
model.push_to_registry(name="my-org/my-team/torch-model")
model.upload_model(name="my-org/my-team/torch-model")

# Pull the model with the same architecture
loaded_model = MyTorchModel.pull_from_registry(name="my-org/my-team/torch-model")
loaded_model = MyTorchModel.download_model(name="my-org/my-team/torch-model")
```
12 changes: 6 additions & 6 deletions src/litmodels/integrations/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class ModelRegistryMixin(ABC):
"""Mixin for model registry integration."""

def push_to_registry(
def upload_model(
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
) -> None:
"""Push the model to the registry.
Expand All @@ -30,7 +30,7 @@ def push_to_registry(
"""

@classmethod
def pull_from_registry(
def download_model(
cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
) -> object:
"""Pull the model from the registry.
Expand Down Expand Up @@ -59,7 +59,7 @@ def _setup(
class PickleRegistryMixin(ModelRegistryMixin):
"""Mixin for pickle registry integration."""

def push_to_registry(
def upload_model(
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
) -> None:
"""Push the model to the registry.
Expand All @@ -78,7 +78,7 @@ def push_to_registry(
upload_model_files(name=name, path=pickle_path)

@classmethod
def pull_from_registry(
def download_model(
cls, name: str, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
) -> object:
"""Pull the model from the registry.
Expand Down Expand Up @@ -127,7 +127,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
instance.__init_kwargs = bound_args.arguments
return instance

def push_to_registry(
def upload_model(
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
) -> None:
"""Push the model to the registry.
Expand Down Expand Up @@ -171,7 +171,7 @@ def push_to_registry(
upload_model_files(name=model_registry, path=[torch_state_dict_path, init_kwargs_path])

@classmethod
def pull_from_registry(
def download_model(
cls,
name: str,
version: Optional[str] = None,
Expand Down
8 changes: 4 additions & 4 deletions tests/integrations/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ def test_pickle_mixin_push_and_pull():

# Create an instance of DummyModel and call push_to_registry.
dummy = DummyModel(42)
dummy.push_to_registry(model_registry)
dummy.upload_model(model_registry)

# Call pull_from_registry and load the DummyModel instance.
loaded_dummy = DummyModel.pull_from_registry(model_registry)
loaded_dummy = DummyModel.download_model(model_registry)
# Verify that the unpickled instance has the expected value.
assert isinstance(loaded_dummy, DummyModel)
assert loaded_dummy.value == 42
Expand Down Expand Up @@ -273,9 +273,9 @@ def test_pytorch_mixin_push_and_pull():
input_tensor = torch.randn(1, 784)
output_before = dummy(input_tensor)

dummy.push_to_registry(model_registry)
dummy.upload_model(model_registry)

loaded_dummy = DummyTorchModel.pull_from_registry(model_registry)
loaded_dummy = DummyTorchModel.download_model(model_registry)
loaded_dummy.eval()
output_after = loaded_dummy(input_tensor)

Expand Down
8 changes: 4 additions & 4 deletions tests/integrations/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def __eq__(self, other):
def test_pickle_push_and_pull(mock_download_model, mock_upload_model, tmp_path):
# Create an instance of DummyModel and call push_to_registry.
dummy = DummyModel(42)
dummy.push_to_registry(version="v1", temp_folder=str(tmp_path))
dummy.upload_model(version="v1", temp_folder=str(tmp_path))
# The expected registry name is "dummy_model:v1" and the file should be placed in the temp folder.
expected_path = tmp_path / "DummyModel.pkl"
mock_upload_model.assert_called_once_with(name="DummyModel:v1", path=expected_path)

# Set the mock to return the full path to the pickle file.
mock_download_model.return_value = ["DummyModel.pkl"]
# Call pull_from_registry and load the DummyModel instance.
loaded_dummy = DummyModel.pull_from_registry(name="dummy_model", version="v1", temp_folder=str(tmp_path))
loaded_dummy = DummyModel.download_model(name="dummy_model", version="v1", temp_folder=str(tmp_path))
# Verify that the unpickled instance has the expected value.
assert loaded_dummy.value == 42

Expand Down Expand Up @@ -71,14 +71,14 @@ def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, torch_cla
with open(json_path, "w") as fp:
fp.write('{"input_size": 784, "output_size": 10}')

dummy.push_to_registry(temp_folder=str(tmp_path))
dummy.upload_model(temp_folder=str(tmp_path))
mock_upload_model.assert_called_once_with(
name=torch_class.__name__, path=[tmp_path / f"{torch_class.__name__}.pth", json_path]
)

# Prepare mocking for pull_from_registry.
mock_download_model.return_value = [torch_file, json_file]
loaded_dummy = torch_class.pull_from_registry(name=torch_class.__name__, temp_folder=str(tmp_path))
loaded_dummy = torch_class.download_model(name=torch_class.__name__, temp_folder=str(tmp_path))
loaded_dummy.eval()
output_after = loaded_dummy(input_tensor)

Expand Down
Loading