2222import subprocess # nosec B404
2323import tempfile
2424from glob import glob
25- from pathlib import Path
2625from typing import Any , Dict , List , Optional , Tuple , Union
2726from warnings import warn
2827
3635 PostProcessingConfigurationCallback ,
3736)
3837from omegaconf import DictConfig , ListConfig
39- from openvino .runtime import Core , serialize
4038from pytorch_lightning import Trainer
4139
4240from otx .algorithms .anomaly .adapters .anomalib .callbacks import (
4745from otx .algorithms .anomaly .adapters .anomalib .data import OTXAnomalyDataModule
4846from otx .algorithms .anomaly .adapters .anomalib .logger import get_logger
4947from otx .algorithms .anomaly .configs .base .configuration import BaseAnomalyConfig
48+ from otx .algorithms .common .utils import embed_ir_model_data
49+ from otx .algorithms .common .utils .utils import embed_onnx_model_data
5050from otx .api .entities .datasets import DatasetEntity
5151from otx .api .entities .inference_parameters import InferenceParameters
5252from otx .api .entities .metrics import NullPerformance , Performance , ScoreMetric
@@ -296,6 +296,8 @@ def export(
296296 self ._export_to_onnx (onnx_path )
297297
298298 if export_type == ExportType .ONNX :
299+ self ._add_metadata_to_ir (onnx_path , export_type )
300+
299301 with open (onnx_path , "rb" ) as file :
300302 output_model .set_data ("model.onnx" , file .read ())
301303 else :
@@ -306,7 +308,7 @@ def export(
306308 bin_file = glob (os .path .join (self .config .project .path , "*.bin" ))[0 ]
307309 xml_file = glob (os .path .join (self .config .project .path , "*.xml" ))[0 ]
308310
309- self ._add_metadata_to_ir (xml_file )
311+ self ._add_metadata_to_ir (xml_file , export_type )
310312
311313 with open (bin_file , "rb" ) as file :
312314 output_model .set_data ("openvino.bin" , file .read ())
@@ -319,40 +321,51 @@ def export(
319321 output_model .set_data ("label_schema.json" , label_schema_to_bytes (self .task_environment .label_schema ))
320322 self ._set_metadata (output_model )
321323
322- def _add_metadata_to_ir (self , xml_file : str ) -> None :
323- """Adds the metadata to the model IR.
324+ def _add_metadata_to_ir (self , model_file : str , export_type : ExportType ) -> None :
325+ """Adds the metadata to the model IR or ONNX .
324326
325327 Adds the metadata to the model IR. So that it can be used with the new modelAPI.
326328 This is because the metadata.json is not used by the new modelAPI.
327329 # TODO CVS-114640
328330 # TODO: Step 1. Remove metadata.json when modelAPI becomes the default inference method.
329- # TODO: Step 2. Remove this function when Anomalib is upgraded as the model graph will contain the required ops
331+ # TODO: Step 2. Update this function when Anomalib is upgraded as the model graph will contain the required ops
330332 # TODO: Step 3. Update modelAPI to remove pre/post-processing steps when Anomalib version is upgraded.
331333 """
332334 metadata = self ._get_metadata_dict ()
333- core = Core ()
334- model = core .read_model (xml_file )
335+ extra_model_data : Dict [Tuple [str , str ], Any ] = {}
335336 for key , value in metadata .items ():
336- if key == "transform" :
337+ if key in ( "transform" , "min" , "max" ) :
337338 continue
338- model . set_rt_info ( value , [ "model_info" , key ])
339+ extra_model_data [( "model_info" , key )] = value
339340 # Add transforms
340341 if "transform" in metadata :
341342 for transform_dict in metadata ["transform" ]["transform" ]["transforms" ]:
342343 transform = transform_dict .pop ("__class_fullname__" )
343344 if transform == "Normalize" :
344- model .set_rt_info (self ._serialize_list (transform_dict ["mean" ]), ["model_info" , "mean_values" ])
345- model .set_rt_info (self ._serialize_list (transform_dict ["std" ]), ["model_info" , "scale_values" ])
345+ extra_model_data [("model_info" , "mean_values" )] = self ._serialize_list (
346+ [x * 255.0 for x in transform_dict ["mean" ]]
347+ )
348+ extra_model_data [("model_info" , "scale_values" )] = self ._serialize_list (
349+ [x * 255.0 for x in transform_dict ["std" ]]
350+ )
346351 elif transform == "Resize" :
347- model . set_rt_info ( transform_dict [ "height" ], [ "model_info" , "orig_height" ])
348- model . set_rt_info ( transform_dict [ "width" ], [ "model_info" , "orig_width" ])
352+ extra_model_data [( "model_info" , "orig_height" )] = transform_dict [ "height" ]
353+ extra_model_data [( "model_info" , "orig_width" )] = transform_dict [ "width" ]
349354 else :
350355 warn (f"Transform { transform } is not supported currently" )
351- model .set_rt_info ("AnomalyDetection" , ["model_info" , "model_type" ])
352- tmp_xml_path = Path (Path (xml_file ).parent ) / "tmp.xml"
353- serialize (model , str (tmp_xml_path ))
354- tmp_xml_path .rename (xml_file )
355- Path (str (tmp_xml_path .parent / tmp_xml_path .stem ) + ".bin" ).unlink ()
356+ # Since we only need the diff of max and min, we fuse the min and max into one op
357+ if "min" in metadata and "max" in metadata :
358+ extra_model_data [("model_info" , "normalization_scale" )] = metadata ["max" ] - metadata ["min" ]
359+
360+ extra_model_data [("model_info" , "reverse_input_channels" )] = False
361+ extra_model_data [("model_info" , "model_type" )] = "AnomalyDetection"
362+ extra_model_data [("model_info" , "labels" )] = "Normal Anomaly"
363+ if export_type == ExportType .OPENVINO :
364+ embed_ir_model_data (model_file , extra_model_data )
365+ elif export_type == ExportType .ONNX :
366+ embed_onnx_model_data (model_file , extra_model_data )
367+ else :
368+ raise RuntimeError (f"not supported export type { export_type } " )
356369
357370 def _serialize_list (self , arr : Union [Tuple , List ]) -> str :
358371 """Converts a list to space separated string."""
0 commit comments