
Commit da3c39b

Refactor model saving (#231)
* Move adapter-specific configuration logic to adapter interface
* Update implementation
* Add tests for saving via onnx adapter
* Fix linter
* Restore create model from adapter
* Update functional tests
* Fix imports order
* Update precommit data deps
* Fix data preparation script
* Fix onnx models loading
* Fix saving logic
1 parent 0da8dc4 commit da3c39b

File tree

11 files changed, +150 -41 lines changed
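After this refactor, Model.save() gathers the wrapper's model_type and parameters into a model_info dict and delegates serialization to the underlying inference adapter (see model.py below). A minimal usage sketch, reusing the test model path that appears elsewhere in this commit; the output file names are illustrative:

from model_api.models import Model

# create_model() accepts a model path, an OVMS URL, or (new in this commit) an InferenceAdapter.
model = Model.create_model("data/otx_models/tinynet_imagenet.xml", model_type="Classification")

# save() collects model_type plus all wrapper parameters into model_info,
# passes them to the adapter via update_model_info(), then calls save_model().
model.save("exported_model.xml", "exported_model.bin")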

.github/workflows/test_precommit.yml

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ jobs:
       - name: Run test
         run: |
           source venv/bin/activate
-          pytest tests/python/funtional
+          pytest --data=./data tests/python/funtional
   CPP-Code-Quality:
     name: CPP-Code-Quality
     runs-on: ubuntu-latest

model_api/python/model_api/adapters/inference_adapter.py

Lines changed: 8 additions & 0 deletions
@@ -163,6 +163,14 @@ def await_any(self):
     def get_rt_info(self, path):
         """Forwards to openvino.Model.get_rt_info(path)"""
 
+    @abstractmethod
+    def update_model_info(self, model_info: dict[str, Any]):
+        """Updates model with the provided model info."""
+
+    @abstractmethod
+    def save_model(self, path: str, weights_path: str, version: str):
+        """Serializes model to the filesystem."""
+
     @abstractmethod
     def embed_preprocessing(
         self,
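With update_model_info() and save_model() added as abstract methods, any custom InferenceAdapter subclass now has to implement them. A minimal sketch of a hypothetical adapter stub (the class name is illustrative and the remaining abstract methods are omitted for brevity):

from typing import Any

from model_api.adapters.inference_adapter import InferenceAdapter


class MyAdapter(InferenceAdapter):  # hypothetical subclass; other abstract methods omitted
    def update_model_info(self, model_info: dict[str, Any]):
        # Attach the serialized wrapper parameters to the in-memory model representation.
        ...

    def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
        # Persist the model, including the info attached above, to the filesystem.
        ...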

model_api/python/model_api/adapters/onnx_adapter.py

Lines changed: 16 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 import sys
 from functools import partial, reduce
+from typing import Any
 
 import numpy as np
 
@@ -55,6 +56,7 @@ def __init__(self, model: str, ort_options: dict = {}):
             inferred_model.SerializeToString(),
             **ort_options,
         )
+        self.model = inferred_model
         self.output_names = [o.name for o in self.session.get_outputs()]
         self.onnx_metadata = load_parameters_from_onnx(inferred_model)
         self.preprocessor = lambda arg: arg
@@ -169,14 +171,27 @@ def embed_preprocessing(
 
     def get_model(self):
         """Return the reference to the ONNXRuntime session."""
-        return self.session
+        return self.model
 
     def reshape_model(self, new_shape):
         raise NotImplementedError
 
     def get_rt_info(self, path):
         return get_rt_info_from_dict(self.onnx_metadata, path)
 
+    def update_model_info(self, model_info: dict[str, Any]):
+        for item in model_info:
+            meta = self.model.metadata_props.add()
+            attr_path = "model_info " + item
+            meta.key = attr_path.strip()
+            if isinstance(model_info[item], list):
+                meta.value = " ".join(str(x) for x in model_info[item])
+            else:
+                meta.value = str(model_info[item])
+
+    def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
+        onnx.save(self.model, path)
+
 
 _onnx2ov_precision = {
     "tensor(float)": "f32",

model_api/python/model_api/adapters/openvino_adapter.py

Lines changed: 13 additions & 0 deletions
@@ -419,6 +419,19 @@ def get_model(self):
         """
         return self.model
 
+    def update_model_info(self, model_info: dict[str, Any]):
+        """
+        Populates OV IR RT info with the given model info.
+
+        Args:
+            model_info (dict[str, Any]): a dict representing the serialized parameters.
+        """
+        for name in model_info:
+            self.model.set_rt_info(model_info[name], ["model_info", name])
+
+    def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
+        ov.serialize(self.get_model(), path, weights_path, version)
+
 
 def get_input_shape(input_tensor: ov.Output) -> list[int]:
     def string_to_tuple(string, casting_type=int):
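The OpenVINO adapter writes the same entries into the IR runtime info, so they survive serialization and can be read back with the standard OpenVINO API. A minimal sketch, assuming a model saved via save_model() to the illustrative paths below:

import openvino as ov

core = ov.Core()
restored = core.read_model("exported_model.xml", "exported_model.bin")
# Entries written by update_model_info() live under ["model_info", <name>]:
print(restored.get_rt_info(["model_info", "model_type"]))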

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@
 #
 
 import re
+from typing import Any
 
 import numpy as np
 
@@ -122,6 +123,14 @@ def get_rt_info(self, path):
         msg = "OVMSAdapter does not support RT info getting"
         raise NotImplementedError(msg)
 
+    def update_model_info(self, model_info: dict[str, Any]):
+        msg = "OVMSAdapter does not support updating model info"
+        raise NotImplementedError(msg)
+
+    def save_model(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
+        msg = "OVMSAdapter does not support saving a model"
+        raise NotImplementedError(msg)
+
 
 _tf2ov_precision = {
     "DT_INT64": "I64",

model_api/python/model_api/models/model.py

Lines changed: 15 additions & 13 deletions
@@ -10,6 +10,7 @@
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, NoReturn, Type
 
+from model_api.adapters.inference_adapter import InferenceAdapter
 from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter
 from model_api.adapters.openvino_adapter import (
     OpenvinoAdapter,
@@ -23,8 +24,6 @@
 
     from numpy import ndarray
 
-    from model_api.adapters.inference_adapter import InferenceAdapter
-
 
 class WrapperError(Exception):
     """The class for errors occurred in Model API wrappers"""
@@ -100,11 +99,7 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
         self.callback_fn = lambda _: None
 
     def get_model(self):
-        model = self.inference_adapter.get_model()
-        model.set_rt_info(self.__model__, ["model_info", "model_type"])
-        for name in self.parameters():
-            model.set_rt_info(getattr(self, name), ["model_info", name])
-        return model
+        return self.inference_adapter.get_model()
 
     @classmethod
     def get_model_class(cls, name: str) -> Type:
@@ -122,7 +117,7 @@ def get_model_class(cls, name: str) -> Type:
     @classmethod
     def create_model(
         cls,
-        model: str,
+        model: str | InferenceAdapter,
         model_type: Any | None = None,
         configuration: dict[str, Any] = {},
         preload: bool = True,
@@ -140,7 +135,7 @@ def create_model(
         """Create an instance of the Model API model
 
         Args:
-            model (str): model name from OpenVINO Model Zoo, path to model, OVMS URL
+            model (str| InferenceAdapter): model name from OpenVINO Model Zoo, path to model, OVMS URL, or an adapter
             configuration (:obj:`dict`, optional): dictionary of model config with model properties, for example
                 confidence_threshold, labels
             model_type (:obj:`str`, optional): name of model wrapper to create (e.g. "ssd")
@@ -162,7 +157,9 @@ def create_model(
             Model object
         """
         inference_adapter: InferenceAdapter
-        if isinstance(model, str) and re.compile(
+        if isinstance(model, InferenceAdapter):
+            inference_adapter = model
+        elif isinstance(model, str) and re.compile(
             r"(\w+\.*\-*)*\w+:\d+\/models\/[a-zA-Z0-9._-]+(\:\d+)*",
         ).fullmatch(model):
             inference_adapter = OVMSAdapter(model)
@@ -487,7 +484,12 @@ def log_layers_info(self):
             f"precision: {metadata.precision}, layout: {metadata.layout}",
         )
 
-    def save(self, xml_path, bin_path="", version="UNSPECIFIED"):
-        import openvino
+    def save(self, path: str, weights_path: str = "", version: str = "UNSPECIFIED"):
+        model_info = {
+            "model_type": self.__model__,
+        }
+        for name in self.parameters():
+            model_info[name] = getattr(self, name)
 
-        openvino.serialize(self.get_model(), xml_path, bin_path, version)
+        self.inference_adapter.update_model_info(model_info)
+        self.inference_adapter.save_model(path, weights_path, version)
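Because create_model() now accepts an InferenceAdapter instance directly, the ONNX save path added by this commit can be exercised end to end. A sketch of that flow; the input file mirrors the model downloaded in prepare_data.py, and the output path is illustrative:

from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter
from model_api.models import Model

# Build the adapter explicitly and hand it to create_model() instead of a path or URL.
adapter = ONNXRuntimeAdapter("data/otx_models/cls_mobilenetv3_large_cars.onnx")
model = Model.create_model(adapter, model_type="Classification")

# update_model_info() adds the "model_info ..." metadata_props; save_model() writes the ONNX file.
model.save("data/otx_models/saved.onnx")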

tests/cpp/precommit/prepare_data.py

Lines changed: 19 additions & 12 deletions
@@ -5,17 +5,23 @@
 from urllib.request import urlopen, urlretrieve
 
 
-def retrieve_otx_model(data_dir, model_name):
-    destenation_folder = os.path.join(data_dir, "otx_models")
-    os.makedirs(destenation_folder, exist_ok=True)
-    urlretrieve(
-        f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/openvino.xml",
-        f"{destenation_folder}/{model_name}.xml",
-    )
-    urlretrieve(
-        f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/openvino.bin",
-        f"{destenation_folder}/{model_name}.bin",
-    )
+def retrieve_otx_model(data_dir, model_name, format="xml"):
+    destination_folder = os.path.join(data_dir, "otx_models")
+    os.makedirs(destination_folder, exist_ok=True)
+    if format == "onnx":
+        urlretrieve(
+            f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/model.onnx",
+            f"{destination_folder}/{model_name}.onnx",
+        )
+    else:
+        urlretrieve(
+            f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/openvino.xml",
+            f"{destination_folder}/{model_name}.xml",
+        )
+        urlretrieve(
+            f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/openvino.bin",
+            f"{destination_folder}/{model_name}.bin",
+        )
 
 
 def prepare_model(
@@ -30,7 +36,7 @@ def prepare_model(
         public_scope = json.load(f)
 
     for model in public_scope:
-        if model["name"].endswith(".xml"):
+        if model["name"].endswith(".xml") or model["name"].endswith(".onnx"):
            continue
         model = eval(model["type"]).create_model(model["name"], download_dir=data_dir)
 
@@ -72,3 +78,4 @@ def prepare_data(data_dir="./data"):
     prepare_data(args.data_dir)
     retrieve_otx_model(args.data_dir, "mlc_mobilenetv3_large_voc")
     retrieve_otx_model(args.data_dir, "tinynet_imagenet")
+    retrieve_otx_model(args.data_dir, "cls_mobilenetv3_large_cars", "onnx")

tests/python/funtional/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+#
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#

tests/python/funtional/conftest.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+#
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption("--data", action="store", help="data folder with dataset")
+
+
+@pytest.fixture(scope="session")
+def data(pytestconfig):
+    return pytestconfig.getoption("data")

tests/python/funtional/test_load.py

Lines changed: 5 additions & 2 deletions
@@ -2,13 +2,16 @@
 # Copyright (C) 2020-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
+
+from pathlib import Path
+
 from model_api.models import Model
 
 
-def test_model_with_unnamed_output_load():
+def test_model_with_unnamed_output_load(data):
     # the model's output doesn't have a name
     _ = Model.create_model(
-        "data/otx_models/tinynet_imagenet.xml",
+        Path(data) / "otx_models/tinynet_imagenet.xml",
         model_type="Classification",
         preload=True,
     )
