Skip to content

Commit 29b633e

Browse files
committed
Fix saving logic
1 parent ced4b2c commit 29b633e

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

model_api/python/model_api/models/model.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,6 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
9999
self.callback_fn = lambda _: None
100100

101101
def get_model(self):
102-
model_info = {
103-
"model_type": self.__model__,
104-
}
105-
for name in self.parameters():
106-
model_info[name] = getattr(self, name)
107-
108-
self.inference_adapter.update_model_info(model_info)
109-
110102
return self.inference_adapter.get_model()
111103

112104
@classmethod
@@ -493,4 +485,11 @@ def log_layers_info(self):
493485
)
494486

495487
def save(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
488+
model_info = {
489+
"model_type": self.__model__,
490+
}
491+
for name in self.parameters():
492+
model_info[name] = getattr(self, name)
493+
494+
self.inference_adapter.update_model_info(model_info)
496495
self.inference_adapter.save_model(path, weights_path, version)

tests/python/funtional/test_save.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from pathlib import Path
77

8+
import onnx
9+
810
from model_api.models import Model
911
from model_api.adapters import ONNXRuntimeAdapter
1012
from model_api.adapters.utils import load_parameters_from_onnx
@@ -15,14 +17,15 @@ def test_detector_save(tmp_path):
1517
"ssd_mobilenet_v1_fpn_coco",
1618
configuration={"mean_values": [0, 0, 0], "confidence_threshold": 0.6},
1719
)
20+
xml_path = str(tmp_path / "a.xml")
21+
downloaded.save(xml_path)
22+
deserialized = Model.create_model(xml_path)
23+
1824
assert (
19-
downloaded.get_model()
25+
deserialized.get_model()
2026
.get_rt_info(["model_info", "embedded_processing"])
2127
.astype(bool)
2228
)
23-
xml_path = str(tmp_path / "a.xml")
24-
downloaded.save(xml_path)
25-
deserialized = Model.create_model(xml_path)
2629
assert type(downloaded) is type(deserialized)
2730
for attr in downloaded.parameters():
2831
assert getattr(downloaded, attr) == getattr(deserialized, attr)
@@ -32,14 +35,15 @@ def test_classifier_save(tmp_path):
3235
downloaded = Model.create_model(
3336
"efficientnet-b0-pytorch", configuration={"scale_values": [1, 1, 1], "topk": 6}
3437
)
38+
xml_path = str(tmp_path / "a.xml")
39+
downloaded.save(xml_path)
40+
deserialized = Model.create_model(xml_path)
41+
3542
assert (
36-
downloaded.get_model()
43+
deserialized.get_model()
3744
.get_rt_info(["model_info", "embedded_processing"])
3845
.astype(bool)
3946
)
40-
xml_path = str(tmp_path / "a.xml")
41-
downloaded.save(xml_path)
42-
deserialized = Model.create_model(xml_path)
4347
assert type(downloaded) is type(deserialized)
4448
for attr in downloaded.parameters():
4549
assert getattr(downloaded, attr) == getattr(deserialized, attr)
@@ -50,14 +54,15 @@ def test_segmentor_save(tmp_path):
5054
"hrnet-v2-c1-segmentation",
5155
configuration={"reverse_input_channels": True, "labels": ["first", "second"]},
5256
)
57+
xml_path = str(tmp_path / "a.xml")
58+
downloaded.save(xml_path)
59+
deserialized = Model.create_model(xml_path)
60+
5361
assert (
54-
downloaded.get_model()
62+
deserialized.get_model()
5563
.get_rt_info(["model_info", "embedded_processing"])
5664
.astype(bool)
5765
)
58-
xml_path = str(tmp_path / "a.xml")
59-
downloaded.save(xml_path)
60-
deserialized = Model.create_model(xml_path)
6166
assert type(downloaded) is type(deserialized)
6267
for attr in downloaded.parameters():
6368
assert getattr(downloaded, attr) == getattr(deserialized, attr)
@@ -71,17 +76,16 @@ def test_onnx_save(tmp_path, data):
7176
configuration={"reverse_input_channels": True, "topk": 6},
7277
)
7378

79+
onnx_path = str(tmp_path / "a.onnx")
80+
cls_model.save(onnx_path)
81+
deserialized = Model.create_model(onnx_path)
82+
7483
assert (
75-
load_parameters_from_onnx(cls_model.get_model())["model_info"][
84+
load_parameters_from_onnx(onnx.load(onnx_path))["model_info"][
7685
"embedded_processing"
7786
]
7887
== "True"
7988
)
80-
81-
onnx_path = str(tmp_path / "a.onnx")
82-
cls_model.save(onnx_path)
83-
84-
deserialized = Model.create_model(onnx_path)
8589
assert type(cls_model) is type(deserialized)
8690
for attr in cls_model.parameters():
8791
assert getattr(cls_model, attr) == getattr(deserialized, attr)

0 commit comments

Comments
 (0)