Skip to content

Commit bf8c801

Browse files
committed
load
1 parent 4db20ee commit bf8c801

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/litmodels/io/gateway.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,6 @@ def load_model(name: str, download_dir: str = ".") -> Any:
112112
model_path = Path(os.path.join(download_dir, download_paths[0]))
113113
if model_path.suffix.lower() == ".pkl":
114114
return joblib.load(model_path)
115+
if model_path.suffix.lower() == ".pt":
116+
return torch.jit.load(model_path)
115117
raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")

tests/test_io_cloud.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import joblib
55
import pytest
6+
import torch
67
import torch.jit as torch_jit
78
from litmodels import download_model, load_model, upload_model
89
from litmodels.io import upload_model_files
@@ -70,7 +71,7 @@ def test_download_model(mock_download_model):
7071

7172

7273
@mock.patch("litmodels.io.cloud.sdk_download_model")
73-
def test_load_model(mock_download_model, tmp_path):
74+
def test_load_model_pickle(mock_download_model, tmp_path):
7475
# create a dummy model file
7576
model_file = tmp_path / "dummy_model.pkl"
7677
test_data = svm.SVC()
@@ -86,3 +87,22 @@ def test_load_model(mock_download_model, tmp_path):
8687
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
8788
)
8889
assert isinstance(model, svm.SVC)
90+
91+
92+
@mock.patch("litmodels.io.cloud.sdk_download_model")
93+
def test_load_model_torch_jit(mock_download_model, tmp_path):
94+
# create a dummy model file
95+
model_file = tmp_path / "dummy_model.pt"
96+
test_data = torch_jit.script(Module())
97+
test_data.save(model_file)
98+
mock_download_model.return_value = [str(model_file.name)]
99+
100+
# The lit-logger function is just a wrapper around the SDK function
101+
model = load_model(
102+
name="org-name/teamspace/model-name",
103+
download_dir=str(tmp_path),
104+
)
105+
mock_download_model.assert_called_once_with(
106+
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
107+
)
108+
assert isinstance(model, torch.jit.ScriptModule)

0 commit comments

Comments
 (0)