55
66from pathlib import Path
77
8+ import onnx
9+
810from model_api .models import Model
911from model_api .adapters import ONNXRuntimeAdapter
1012from model_api .adapters .utils import load_parameters_from_onnx
@@ -15,14 +17,15 @@ def test_detector_save(tmp_path):
1517 "ssd_mobilenet_v1_fpn_coco" ,
1618 configuration = {"mean_values" : [0 , 0 , 0 ], "confidence_threshold" : 0.6 },
1719 )
20+ xml_path = str (tmp_path / "a.xml" )
21+ downloaded .save (xml_path )
22+ deserialized = Model .create_model (xml_path )
23+
1824 assert (
19- downloaded .get_model ()
25+ deserialized .get_model ()
2026 .get_rt_info (["model_info" , "embedded_processing" ])
2127 .astype (bool )
2228 )
23- xml_path = str (tmp_path / "a.xml" )
24- downloaded .save (xml_path )
25- deserialized = Model .create_model (xml_path )
2629 assert type (downloaded ) is type (deserialized )
2730 for attr in downloaded .parameters ():
2831 assert getattr (downloaded , attr ) == getattr (deserialized , attr )
@@ -32,14 +35,15 @@ def test_classifier_save(tmp_path):
3235 downloaded = Model .create_model (
3336 "efficientnet-b0-pytorch" , configuration = {"scale_values" : [1 , 1 , 1 ], "topk" : 6 }
3437 )
38+ xml_path = str (tmp_path / "a.xml" )
39+ downloaded .save (xml_path )
40+ deserialized = Model .create_model (xml_path )
41+
3542 assert (
36- downloaded .get_model ()
43+ deserialized .get_model ()
3744 .get_rt_info (["model_info" , "embedded_processing" ])
3845 .astype (bool )
3946 )
40- xml_path = str (tmp_path / "a.xml" )
41- downloaded .save (xml_path )
42- deserialized = Model .create_model (xml_path )
4347 assert type (downloaded ) is type (deserialized )
4448 for attr in downloaded .parameters ():
4549 assert getattr (downloaded , attr ) == getattr (deserialized , attr )
@@ -50,14 +54,15 @@ def test_segmentor_save(tmp_path):
5054 "hrnet-v2-c1-segmentation" ,
5155 configuration = {"reverse_input_channels" : True , "labels" : ["first" , "second" ]},
5256 )
57+ xml_path = str (tmp_path / "a.xml" )
58+ downloaded .save (xml_path )
59+ deserialized = Model .create_model (xml_path )
60+
5361 assert (
54- downloaded .get_model ()
62+ deserialized .get_model ()
5563 .get_rt_info (["model_info" , "embedded_processing" ])
5664 .astype (bool )
5765 )
58- xml_path = str (tmp_path / "a.xml" )
59- downloaded .save (xml_path )
60- deserialized = Model .create_model (xml_path )
6166 assert type (downloaded ) is type (deserialized )
6267 for attr in downloaded .parameters ():
6368 assert getattr (downloaded , attr ) == getattr (deserialized , attr )
@@ -71,17 +76,16 @@ def test_onnx_save(tmp_path, data):
7176 configuration = {"reverse_input_channels" : True , "topk" : 6 },
7277 )
7378
79+ onnx_path = str (tmp_path / "a.onnx" )
80+ cls_model .save (onnx_path )
81+ deserialized = Model .create_model (onnx_path )
82+
7483 assert (
75- load_parameters_from_onnx (cls_model . get_model ( ))["model_info" ][
84+ load_parameters_from_onnx (onnx . load ( onnx_path ))["model_info" ][
7685 "embedded_processing"
7786 ]
7887 == "True"
7988 )
80-
81- onnx_path = str (tmp_path / "a.onnx" )
82- cls_model .save (onnx_path )
83-
84- deserialized = Model .create_model (onnx_path )
8589 assert type (cls_model ) is type (deserialized )
8690 for attr in cls_model .parameters ():
8791 assert getattr (cls_model , attr ) == getattr (deserialized , attr )
0 commit comments