33
44import joblib
55import pytest
6+ import torch
7+ import torch .jit as torch_jit
68from litmodels import download_model , load_model , upload_model
79from litmodels .io import upload_model_files
810from sklearn import svm
@@ -30,6 +32,7 @@ def test_download_wrong_model_name(name):
3032 [
3133 ("path/to/checkpoint" , "path/to/checkpoint" , False ),
3234 # (BoringModel(), "%s/BoringModel.ckpt"),
35+ (torch_jit .script (Module ()), f"%s{ os .path .sep } RecursiveScriptModule.ts" , True ),
3336 (Module (), f"%s{ os .path .sep } Module.pth" , True ),
3437 (svm .SVC (), f"%s{ os .path .sep } SVC.pkl" , 1 ),
3538 ],
@@ -68,7 +71,7 @@ def test_download_model(mock_download_model):
6871
6972
7073@mock .patch ("litmodels.io.cloud.sdk_download_model" )
71- def test_load_model (mock_download_model , tmp_path ):
74+ def test_load_model_pickle (mock_download_model , tmp_path ):
7275 # create a dummy model file
7376 model_file = tmp_path / "dummy_model.pkl"
7477 test_data = svm .SVC ()
@@ -84,3 +87,22 @@ def test_load_model(mock_download_model, tmp_path):
8487 name = "org-name/teamspace/model-name" , download_dir = str (tmp_path ), progress_bar = True
8588 )
8689 assert isinstance (model , svm .SVC )
90+
91+
92+ @mock .patch ("litmodels.io.cloud.sdk_download_model" )
93+ def test_load_model_torch_jit (mock_download_model , tmp_path ):
94+ # create a dummy model file
95+ model_file = tmp_path / "dummy_model.ts"
96+ test_data = torch_jit .script (Module ())
97+ test_data .save (model_file )
98+ mock_download_model .return_value = [str (model_file .name )]
99+
100+ # The lit-logger function is just a wrapper around the SDK function
101+ model = load_model (
102+ name = "org-name/teamspace/model-name" ,
103+ download_dir = str (tmp_path ),
104+ )
105+ mock_download_model .assert_called_once_with (
106+ name = "org-name/teamspace/model-name" , download_dir = str (tmp_path ), progress_bar = True
107+ )
108+ assert isinstance (model , torch .jit .ScriptModule )
0 commit comments