Skip to content

Commit 8b9492a

Browse files
authored
update uploading multiple files for mixins (#71)
* update * mock * typing
1 parent b63ef1d commit 8b9492a

File tree

5 files changed

+12
-7
lines changed

5 files changed

+12
-7
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class MyModel(PickleRegistryMixin):
192192
self.param1 = param1
193193
self.param2 = param2
194194
# Your model initialization code
195+
...
195196

196197

197198
# Create and push a model instance

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# NOTE: once we add more dependencies, consider update dependabot to check for updates
22

3-
lightning-sdk >=0.2.5
3+
lightning-sdk >=0.2.7
44
lightning-utilities
55
joblib

src/litmodels/integrations/mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def push_to_registry(
168168
model_registry = f"{name}:{version}" if version else name
169169
# todo: consider creating another temp folder and copying these two files
170170
# todo: updating SDK to support uploading just specific files
171-
upload_model_files(name=model_registry, path=temp_folder)
171+
upload_model_files(name=model_registry, path=[torch_state_dict_path, init_kwargs_path])
172172

173173
@classmethod
174174
def pull_from_registry(

src/litmodels/io/cloud.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _print_model_link(name: str, verbose: Union[bool, int]) -> None:
4242

4343
def upload_model_files(
4444
name: str,
45-
path: Union[str, Path],
45+
path: Union[str, Path, List[Union[str, Path]]],
4646
progress_bar: bool = True,
4747
cloud_account: Optional[str] = None,
4848
verbose: Union[bool, int] = 1,

tests/integrations/test_mixins.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,18 @@ def test_pytorch_push_and_pull(mock_download_model, mock_upload_model, torch_cla
6464
input_tensor = torch.randn(1, 784)
6565
output_before = dummy(input_tensor)
6666

67-
dummy.push_to_registry(temp_folder=str(tmp_path))
68-
mock_upload_model.assert_called_once_with(name=torch_class.__name__, path=str(tmp_path))
69-
7067
torch_file = f"{dummy.__class__.__name__}.pth"
7168
torch.save(dummy.state_dict(), tmp_path / torch_file)
7269
json_file = f"{dummy.__class__.__name__}__init_kwargs.json"
73-
with open(tmp_path / json_file, "w") as fp:
70+
json_path = tmp_path / json_file
71+
with open(json_path, "w") as fp:
7472
fp.write('{"input_size": 784, "output_size": 10}')
73+
74+
dummy.push_to_registry(temp_folder=str(tmp_path))
75+
mock_upload_model.assert_called_once_with(
76+
name=torch_class.__name__, path=[tmp_path / f"{torch_class.__name__}.pth", json_path]
77+
)
78+
7579
# Prepare mocking for pull_from_registry.
7680
mock_download_model.return_value = [torch_file, json_file]
7781
loaded_dummy = torch_class.pull_from_registry(name=torch_class.__name__, temp_folder=str(tmp_path))

0 commit comments

Comments
 (0)