|
3 | 3 | # Copyright (C) 2021-2023 Intel Corporation |
4 | 4 | # SPDX-License-Identifier: Apache-2.0 |
5 | 5 |
|
6 | | -import pytest |
| 6 | +import json |
7 | 7 | from copy import deepcopy |
| 8 | +from pathlib import Path |
| 9 | +from tempfile import TemporaryDirectory |
| 10 | +from unittest.mock import MagicMock, patch |
8 | 11 |
|
9 | 12 | import numpy as np |
| 13 | +import pytest |
10 | 14 |
|
11 | 15 | from otx.algorithms.anomaly.tasks.openvino import OpenVINOTask |
12 | 16 | from otx.algorithms.anomaly.tasks.train import TrainingTask |
13 | 17 | from otx.api.entities.datasets import DatasetEntity |
14 | 18 | from otx.api.entities.inference_parameters import InferenceParameters |
| 19 | +from otx.api.entities.label import Domain, LabelEntity |
| 20 | +from otx.api.entities.label_schema import LabelSchemaEntity |
15 | 21 | from otx.api.entities.model import ModelEntity, ModelOptimizationType |
16 | 22 | from otx.api.entities.model_template import TaskType |
17 | 23 | from otx.api.entities.optimization_parameters import OptimizationParameters |
18 | 24 | from otx.api.entities.resultset import ResultSetEntity |
19 | 25 | from otx.api.entities.subset import Subset |
| 26 | +from otx.api.entities.task_environment import TaskEnvironment |
20 | 27 | from otx.api.usecases.tasks.interfaces.export_interface import ExportType |
21 | 28 | from otx.api.usecases.tasks.interfaces.optimization_interface import OptimizationType |
| 29 | +from otx.cli.utils.io import read_model |
22 | 30 |
|
23 | 31 |
|
24 | 32 | class TestOpenVINOTask: |
25 | 33 | """Tests methods in the OpenVINO task.""" |
26 | 34 |
|
| 35 | + @pytest.fixture |
| 36 | + def tmp_dir(self): |
| 37 | + with TemporaryDirectory() as tmp_dir: |
| 38 | + yield tmp_dir |
| 39 | + |
27 | 40 | def set_normalization_params(self, output_model: ModelEntity): |
28 | 41 | """Sets normalization parameters for an untrained output model. |
29 | 42 |
|
@@ -77,3 +90,54 @@ def test_openvino(self, tmpdir, setup_task_environment): |
77 | 90 | # deploy |
78 | 91 | openvino_task.deploy(output_model) |
79 | 92 | assert output_model.exportable_code is not None |
| 93 | + |
| 94 | + @patch.multiple(OpenVINOTask, get_config=MagicMock(), load_inferencer=MagicMock()) |
| 95 | + @patch("otx.algorithms.anomaly.tasks.openvino.get_transforms", MagicMock()) |
| 96 | + def test_anomaly_legacy_keys(self, mocker, tmp_dir): |
| 97 | + """Checks whether the model is loaded correctly with legacy and current keys.""" |
| 98 | + |
| 99 | + tmp_dir = Path(tmp_dir) |
| 100 | + xml_model_path = tmp_dir / "model.xml" |
| 101 | + xml_model_path.write_text("xml_model") |
| 102 | + bin_model_path = tmp_dir / "model.bin" |
| 103 | + bin_model_path.write_text("bin_model") |
| 104 | + |
| 105 | + # Test loading legacy keys |
| 106 | + legacy_keys = ("image_threshold", "pixel_threshold", "min", "max") |
| 107 | + for key in legacy_keys: |
| 108 | + (tmp_dir / key).write_bytes(np.zeros(1, dtype=np.float32).tobytes()) |
| 109 | + |
| 110 | + model = read_model(mocker.MagicMock(), str(xml_model_path), mocker.MagicMock()) |
| 111 | + task_environment = TaskEnvironment( |
| 112 | + model_template=mocker.MagicMock(), |
| 113 | + model=model, |
| 114 | + hyper_parameters=mocker.MagicMock(), |
| 115 | + label_schema=LabelSchemaEntity.from_labels( |
| 116 | + [ |
| 117 | + LabelEntity("Anomalous", is_anomalous=True, domain=Domain.ANOMALY_SEGMENTATION), |
| 118 | + LabelEntity("Normal", domain=Domain.ANOMALY_SEGMENTATION), |
| 119 | + ] |
| 120 | + ), |
| 121 | + ) |
| 122 | + openvino_task = OpenVINOTask(task_environment) |
| 123 | + metadata = openvino_task.get_metadata() |
| 124 | + for key in legacy_keys: |
| 125 | + assert metadata[key] == np.zeros(1, dtype=np.float32) |
| 126 | + |
| 127 | + # cleanup legacy keys |
| 128 | + for key in legacy_keys: |
| 129 | + (tmp_dir / key).unlink() |
| 130 | + |
| 131 | + # Test loading new keys |
| 132 | + new_metadata = { |
| 133 | + "image_threshold": np.zeros(1, dtype=np.float32).tolist(), |
| 134 | + "pixel_threshold": np.zeros(1, dtype=np.float32).tolist(), |
| 135 | + "min": np.zeros(1, dtype=np.float32).tolist(), |
| 136 | + "max": np.zeros(1, dtype=np.float32).tolist(), |
| 137 | + } |
| 138 | + (tmp_dir / "metadata").write_bytes(json.dumps(new_metadata).encode()) |
| 139 | + task_environment.model = read_model(mocker.MagicMock(), str(xml_model_path), mocker.MagicMock()) |
| 140 | + openvino_task = OpenVINOTask(task_environment) |
| 141 | + metadata = openvino_task.get_metadata() |
| 142 | + for key in new_metadata.keys(): |
| 143 | + assert metadata[key] == np.zeros(1, dtype=np.float32) |
0 commit comments