Skip to content

Commit 4cc9d8a

Browse files
authored
dump sklearn model to a file for upload (#37)
* dump sklearn model to a file for upload * linter * Any
1 parent f172e4f commit 4cc9d8a

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed

_requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pytest-cov
44
pytest-mock
55

66
pytorch-lightning >=2.0
7+
scikit-learn >=1.0

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
lightning-sdk >=0.1.40
44
lightning-utilities
5+
joblib

src/litmodels/io/gateway.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
22
import tempfile
33
from pathlib import Path
4-
from typing import TYPE_CHECKING, List, Optional, Union
4+
from typing import TYPE_CHECKING, Any, List, Optional, Union
55

6+
import joblib
67
from lightning_utilities import module_available
78

89
from litmodels.io.cloud import download_model_files, upload_model_files
@@ -19,7 +20,7 @@
1920

2021
def upload_model(
2122
name: str,
22-
model: Union[str, Path, "Module"],
23+
model: Union[str, Path, "Module", Any],
2324
progress_bar: bool = True,
2425
cloud_account: Optional[str] = None,
2526
staging_dir: Optional[str] = None,
@@ -52,7 +53,9 @@ def upload_model(
5253
elif isinstance(model, Path):
5354
path = str(model)
5455
else:
55-
raise ValueError(f"Unsupported model type {type(model)}")
56+
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
57+
joblib.dump(model, path)
58+
5659
return upload_model_files(
5760
path=path,
5861
name=name,

tests/test_io_cloud.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from litmodels import download_model, upload_model
66
from litmodels.io import upload_model_files
7+
from sklearn import svm
78
from torch.nn import Module
89

910

@@ -21,6 +22,7 @@ def test_wrong_model_name(name):
2122
("path/to/checkpoint", "path/to/checkpoint", False),
2223
# (BoringModel(), "%s/BoringModel.ckpt"),
2324
(Module(), f"%s{os.path.sep}Module.pth", True),
25+
(svm.SVC(), f"%s{os.path.sep}SVC.pkl", 1),
2426
],
2527
)
2628
@mock.patch("litmodels.io.cloud.sdk_upload_model")

0 commit comments

Comments
 (0)