33
44import joblib
55import pytest
6+ import torch
67import torch .jit as torch_jit
78from litmodels import download_model , load_model , upload_model
89from litmodels .io import upload_model_files
@@ -70,7 +71,7 @@ def test_download_model(mock_download_model):
7071
7172
7273@mock .patch ("litmodels.io.cloud.sdk_download_model" )
73- def test_load_model (mock_download_model , tmp_path ):
74+ def test_load_model_pickle (mock_download_model , tmp_path ):
7475 # create a dummy model file
7576 model_file = tmp_path / "dummy_model.pkl"
7677 test_data = svm .SVC ()
@@ -86,3 +87,22 @@ def test_load_model(mock_download_model, tmp_path):
8687 name = "org-name/teamspace/model-name" , download_dir = str (tmp_path ), progress_bar = True
8788 )
8889 assert isinstance (model , svm .SVC )
90+
91+
92+ @mock .patch ("litmodels.io.cloud.sdk_download_model" )
93+ def test_load_model_torch_jit (mock_download_model , tmp_path ):
94+ # create a dummy model file
95+ model_file = tmp_path / "dummy_model.pt"
96+ test_data = torch_jit .script (Module ())
97+ test_data .save (model_file )
98+ mock_download_model .return_value = [str (model_file .name )]
99+
100+ # The lit-logger function is just a wrapper around the SDK function
101+ model = load_model (
102+ name = "org-name/teamspace/model-name" ,
103+ download_dir = str (tmp_path ),
104+ )
105+ mock_download_model .assert_called_once_with (
106+ name = "org-name/teamspace/model-name" , download_dir = str (tmp_path ), progress_bar = True
107+ )
108+ assert isinstance (model , torch .jit .ScriptModule )
0 commit comments