Skip to content

Commit 93d66ed

Browse files
ashwinvaidya17Ashwin Vaidya
andauthored
Update openvino export (#2305)
* Fix license in PR template * Add set_ir for modelAPI * Docstrings * Better serialization of transforms * Address PR comments --------- Co-authored-by: Ashwin Vaidya <[email protected]>
1 parent 6fb59db commit 93d66ed

File tree

1 file changed

+53
-2
lines changed

1 file changed

+53
-2
lines changed

src/otx/algorithms/anomaly/tasks/inference.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import subprocess # nosec B404
2323
import tempfile
2424
from glob import glob
25-
from typing import Dict, List, Optional, Union
25+
from pathlib import Path
26+
from typing import Any, Dict, List, Optional, Tuple, Union
27+
from warnings import warn
2628

2729
import torch
2830
from anomalib.data.utils.transform import get_transforms
@@ -34,6 +36,7 @@
3436
PostProcessingConfigurationCallback,
3537
)
3638
from omegaconf import DictConfig, ListConfig
39+
from openvino.runtime import Core, serialize
3740
from pytorch_lightning import Trainer
3841

3942
from otx.algorithms.anomaly.adapters.anomalib.callbacks import (
@@ -302,6 +305,9 @@ def export(
302305
subprocess.run(optimize_command, check=True)
303306
bin_file = glob(os.path.join(self.config.project.path, "*.bin"))[0]
304307
xml_file = glob(os.path.join(self.config.project.path, "*.xml"))[0]
308+
309+
self._add_metadata_to_ir(xml_file)
310+
305311
with open(bin_file, "rb") as file:
306312
output_model.set_data("openvino.bin", file.read())
307313
with open(xml_file, "rb") as file:
@@ -313,6 +319,45 @@ def export(
313319
output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema))
314320
self._set_metadata(output_model)
315321

322+
def _add_metadata_to_ir(self, xml_file: str) -> None:
323+
"""Adds the metadata to the model IR.
324+
325+
Adds the metadata to the model IR. So that it can be used with the new modelAPI.
326+
This is because the metadata.json is not used by the new modelAPI.
327+
# TODO CVS-114640
328+
# TODO: Step 1. Remove metadata.json when modelAPI becomes the default inference method.
329+
# TODO: Step 2. Remove this function when Anomalib is upgraded as the model graph will contain the required ops
330+
# TODO: Step 3. Update modelAPI to remove pre/post-processing steps when Anomalib version is upgraded.
331+
"""
332+
metadata = self._get_metadata_dict()
333+
core = Core()
334+
model = core.read_model(xml_file)
335+
for key, value in metadata.items():
336+
if key == "transform":
337+
continue
338+
model.set_rt_info(value, ["model_info", key])
339+
# Add transforms
340+
if "transform" in metadata:
341+
for transform_dict in metadata["transform"]["transform"]["transforms"]:
342+
transform = transform_dict.pop("__class_fullname__")
343+
if transform == "Normalize":
344+
model.set_rt_info(self._serialize_list(transform_dict["mean"]), ["model_info", "mean_values"])
345+
model.set_rt_info(self._serialize_list(transform_dict["std"]), ["model_info", "scale_values"])
346+
elif transform == "Resize":
347+
model.set_rt_info(transform_dict["height"], ["model_info", "orig_height"])
348+
model.set_rt_info(transform_dict["width"], ["model_info", "orig_width"])
349+
else:
350+
warn(f"Transform {transform} is not supported currently")
351+
model.set_rt_info("AnomalyDetection", ["model_info", "model_type"])
352+
tmp_xml_path = Path(Path(xml_file).parent) / "tmp.xml"
353+
serialize(model, str(tmp_xml_path))
354+
tmp_xml_path.rename(xml_file)
355+
Path(str(tmp_xml_path.parent / tmp_xml_path.stem) + ".bin").unlink()
356+
357+
def _serialize_list(self, arr: Union[Tuple, List]) -> str:
358+
"""Converts a list to space separated string."""
359+
return " ".join(map(str, arr))
360+
316361
def model_info(self) -> Dict:
317362
"""Return model info to save the model weights.
318363
@@ -348,6 +393,12 @@ def save_model(self, output_model: ModelEntity) -> None:
348393
output_model.optimization_methods = self.optimization_methods
349394

350395
def _set_metadata(self, output_model: ModelEntity):
396+
"""Sets metadata in output_model."""
397+
metadata = self._get_metadata_dict()
398+
output_model.set_data("metadata", json.dumps(metadata).encode())
399+
400+
def _get_metadata_dict(self) -> Dict[str, Any]:
401+
"""Returns metadata dict."""
351402
image_threshold = (
352403
self.model.image_threshold.value.cpu().numpy().tolist() if hasattr(self.model, "image_threshold") else 0.5
353404
)
@@ -384,7 +435,7 @@ def _set_metadata(self, output_model: ModelEntity):
384435
metadata["max"] = max
385436
# Set the task type for inferencer
386437
metadata["task"] = str(self.task_type).lower().split("_")[-1]
387-
output_model.set_data("metadata", json.dumps(metadata).encode())
438+
return metadata
388439

389440
@staticmethod
390441
def _is_docker() -> bool:

0 commit comments

Comments
 (0)