Skip to content

Commit 948881e

Browse files
Fix OpenVINO inference for legacy models (#2450)
* bug fix for legacy openvino models * Add tests * Specific exceptions ---------
1 parent 498bd85 commit 948881e

File tree

3 files changed

+106
-6
lines changed

3 files changed

+106
-6
lines changed

src/otx/algorithms/anomaly/tasks/openvino.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import numpy as np
2727
import openvino.runtime as ov
2828
from addict import Dict as ADDict
29+
from anomalib.data.utils.transform import get_transforms
2930
from anomalib.deploy import OpenVINOInferencer
3031
from nncf.common.quantization.structs import QuantizationPreset
3132
from omegaconf import OmegaConf
@@ -216,16 +217,47 @@ def get_metadata(self) -> Dict:
216217
"""Get Meta Data."""
217218
metadata = {}
218219
if self.task_environment.model is not None:
219-
metadata = json.loads(self.task_environment.model.get_data("metadata").decode())
220-
metadata["image_threshold"] = np.array(metadata["image_threshold"], dtype=np.float32).item()
221-
metadata["pixel_threshold"] = np.array(metadata["pixel_threshold"], dtype=np.float32).item()
222-
metadata["min"] = np.array(metadata["min"], dtype=np.float32).item()
223-
metadata["max"] = np.array(metadata["max"], dtype=np.float32).item()
220+
try:
221+
metadata = json.loads(self.task_environment.model.get_data("metadata").decode())
222+
self._populate_metadata(metadata)
223+
logger.info("Metadata loaded from model v1.4.")
224+
except (KeyError, json.decoder.JSONDecodeError):
225+
# model is from version 1.2.x
226+
metadata = self._populate_metadata_legacy(self.task_environment.model)
227+
logger.info("Metadata loaded from model v1.2.x.")
224228
else:
225229
raise ValueError("Cannot access meta-data. self.task_environment.model is empty.")
226230

227231
return metadata
228232

233+
def _populate_metadata_legacy(self, model: ModelEntity) -> Dict[str, Any]:
234+
"""Populates metadata for models for version 1.2.x."""
235+
image_threshold = np.frombuffer(model.get_data("image_threshold"), dtype=np.float32)
236+
pixel_threshold = np.frombuffer(model.get_data("pixel_threshold"), dtype=np.float32)
237+
min_value = np.frombuffer(model.get_data("min"), dtype=np.float32)
238+
max_value = np.frombuffer(model.get_data("max"), dtype=np.float32)
239+
transform = get_transforms(
240+
config=self.config.dataset.transform_config.train,
241+
image_size=tuple(self.config.dataset.image_size),
242+
to_tensor=True,
243+
)
244+
metadata = {
245+
"transform": transform.to_dict(),
246+
"image_threshold": image_threshold,
247+
"pixel_threshold": pixel_threshold,
248+
"min": min_value,
249+
"max": max_value,
250+
"task": str(self.task_type).lower().split("_")[-1],
251+
}
252+
return metadata
253+
254+
def _populate_metadata(self, metadata: Dict[str, Any]):
255+
"""Populates metadata for models from version 1.4 onwards."""
256+
metadata["image_threshold"] = np.array(metadata["image_threshold"], dtype=np.float32).item()
257+
metadata["pixel_threshold"] = np.array(metadata["pixel_threshold"], dtype=np.float32).item()
258+
metadata["min"] = np.array(metadata["min"], dtype=np.float32).item()
259+
metadata["max"] = np.array(metadata["max"], dtype=np.float32).item()
260+
229261
def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None):
230262
"""Evaluate the performance of the model.
231263

src/otx/cli/utils/io.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
"visual_prompting_image_encoder.bin",
5252
"visual_prompting_decoder.xml",
5353
"visual_prompting_decoder.bin",
54+
"image_threshold", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded.
55+
"pixel_threshold", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded.
56+
"min", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded.
57+
"max", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded.
5458
)
5559

5660

tests/unit/algorithms/anomaly/tasks/test_openvino.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,40 @@
33
# Copyright (C) 2021-2023 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
import pytest
6+
import json
77
from copy import deepcopy
8+
from pathlib import Path
9+
from tempfile import TemporaryDirectory
10+
from unittest.mock import MagicMock, patch
811

912
import numpy as np
13+
import pytest
1014

1115
from otx.algorithms.anomaly.tasks.openvino import OpenVINOTask
1216
from otx.algorithms.anomaly.tasks.train import TrainingTask
1317
from otx.api.entities.datasets import DatasetEntity
1418
from otx.api.entities.inference_parameters import InferenceParameters
19+
from otx.api.entities.label import Domain, LabelEntity
20+
from otx.api.entities.label_schema import LabelSchemaEntity
1521
from otx.api.entities.model import ModelEntity, ModelOptimizationType
1622
from otx.api.entities.model_template import TaskType
1723
from otx.api.entities.optimization_parameters import OptimizationParameters
1824
from otx.api.entities.resultset import ResultSetEntity
1925
from otx.api.entities.subset import Subset
26+
from otx.api.entities.task_environment import TaskEnvironment
2027
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
2128
from otx.api.usecases.tasks.interfaces.optimization_interface import OptimizationType
29+
from otx.cli.utils.io import read_model
2230

2331

2432
class TestOpenVINOTask:
2533
"""Tests methods in the OpenVINO task."""
2634

35+
@pytest.fixture
36+
def tmp_dir(self):
37+
with TemporaryDirectory() as tmp_dir:
38+
yield tmp_dir
39+
2740
def set_normalization_params(self, output_model: ModelEntity):
2841
"""Sets normalization parameters for an untrained output model.
2942
@@ -77,3 +90,54 @@ def test_openvino(self, tmpdir, setup_task_environment):
7790
# deploy
7891
openvino_task.deploy(output_model)
7992
assert output_model.exportable_code is not None
93+
94+
@patch.multiple(OpenVINOTask, get_config=MagicMock(), load_inferencer=MagicMock())
95+
@patch("otx.algorithms.anomaly.tasks.openvino.get_transforms", MagicMock())
96+
def test_anomaly_legacy_keys(self, mocker, tmp_dir):
97+
"""Checks whether the model is loaded correctly with legacy and current keys."""
98+
99+
tmp_dir = Path(tmp_dir)
100+
xml_model_path = tmp_dir / "model.xml"
101+
xml_model_path.write_text("xml_model")
102+
bin_model_path = tmp_dir / "model.bin"
103+
bin_model_path.write_text("bin_model")
104+
105+
# Test loading legacy keys
106+
legacy_keys = ("image_threshold", "pixel_threshold", "min", "max")
107+
for key in legacy_keys:
108+
(tmp_dir / key).write_bytes(np.zeros(1, dtype=np.float32).tobytes())
109+
110+
model = read_model(mocker.MagicMock(), str(xml_model_path), mocker.MagicMock())
111+
task_environment = TaskEnvironment(
112+
model_template=mocker.MagicMock(),
113+
model=model,
114+
hyper_parameters=mocker.MagicMock(),
115+
label_schema=LabelSchemaEntity.from_labels(
116+
[
117+
LabelEntity("Anomalous", is_anomalous=True, domain=Domain.ANOMALY_SEGMENTATION),
118+
LabelEntity("Normal", domain=Domain.ANOMALY_SEGMENTATION),
119+
]
120+
),
121+
)
122+
openvino_task = OpenVINOTask(task_environment)
123+
metadata = openvino_task.get_metadata()
124+
for key in legacy_keys:
125+
assert metadata[key] == np.zeros(1, dtype=np.float32)
126+
127+
# cleanup legacy keys
128+
for key in legacy_keys:
129+
(tmp_dir / key).unlink()
130+
131+
# Test loading new keys
132+
new_metadata = {
133+
"image_threshold": np.zeros(1, dtype=np.float32).tolist(),
134+
"pixel_threshold": np.zeros(1, dtype=np.float32).tolist(),
135+
"min": np.zeros(1, dtype=np.float32).tolist(),
136+
"max": np.zeros(1, dtype=np.float32).tolist(),
137+
}
138+
(tmp_dir / "metadata").write_bytes(json.dumps(new_metadata).encode())
139+
task_environment.model = read_model(mocker.MagicMock(), str(xml_model_path), mocker.MagicMock())
140+
openvino_task = OpenVINOTask(task_environment)
141+
metadata = openvino_task.get_metadata()
142+
for key in new_metadata.keys():
143+
assert metadata[key] == np.zeros(1, dtype=np.float32)

0 commit comments

Comments
 (0)