2222import subprocess # nosec B404
2323import tempfile
2424from glob import glob
25- from typing import Dict , List , Optional , Union
25+ from pathlib import Path
26+ from typing import Any , Dict , List , Optional , Tuple , Union
27+ from warnings import warn
2628
2729import torch
2830from anomalib .data .utils .transform import get_transforms
3436 PostProcessingConfigurationCallback ,
3537)
3638from omegaconf import DictConfig , ListConfig
39+ from openvino .runtime import Core , serialize
3740from pytorch_lightning import Trainer
3841
3942from otx .algorithms .anomaly .adapters .anomalib .callbacks import (
@@ -302,6 +305,9 @@ def export(
302305 subprocess .run (optimize_command , check = True )
303306 bin_file = glob (os .path .join (self .config .project .path , "*.bin" ))[0 ]
304307 xml_file = glob (os .path .join (self .config .project .path , "*.xml" ))[0 ]
308+
309+ self ._add_metadata_to_ir (xml_file )
310+
305311 with open (bin_file , "rb" ) as file :
306312 output_model .set_data ("openvino.bin" , file .read ())
307313 with open (xml_file , "rb" ) as file :
@@ -313,6 +319,45 @@ def export(
313319 output_model .set_data ("label_schema.json" , label_schema_to_bytes (self .task_environment .label_schema ))
314320 self ._set_metadata (output_model )
315321
322+ def _add_metadata_to_ir (self , xml_file : str ) -> None :
323+ """Adds the metadata to the model IR.
324+
325+ Adds the metadata to the model IR. So that it can be used with the new modelAPI.
326+ This is because the metadata.json is not used by the new modelAPI.
327+ # TODO CVS-114640
328+ # TODO: Step 1. Remove metadata.json when modelAPI becomes the default inference method.
329+ # TODO: Step 2. Remove this function when Anomalib is upgraded as the model graph will contain the required ops
330+ # TODO: Step 3. Update modelAPI to remove pre/post-processing steps when Anomalib version is upgraded.
331+ """
332+ metadata = self ._get_metadata_dict ()
333+ core = Core ()
334+ model = core .read_model (xml_file )
335+ for key , value in metadata .items ():
336+ if key == "transform" :
337+ continue
338+ model .set_rt_info (value , ["model_info" , key ])
339+ # Add transforms
340+ if "transform" in metadata :
341+ for transform_dict in metadata ["transform" ]["transform" ]["transforms" ]:
342+ transform = transform_dict .pop ("__class_fullname__" )
343+ if transform == "Normalize" :
344+ model .set_rt_info (self ._serialize_list (transform_dict ["mean" ]), ["model_info" , "mean_values" ])
345+ model .set_rt_info (self ._serialize_list (transform_dict ["std" ]), ["model_info" , "scale_values" ])
346+ elif transform == "Resize" :
347+ model .set_rt_info (transform_dict ["height" ], ["model_info" , "orig_height" ])
348+ model .set_rt_info (transform_dict ["width" ], ["model_info" , "orig_width" ])
349+ else :
350+ warn (f"Transform { transform } is not supported currently" )
351+ model .set_rt_info ("AnomalyDetection" , ["model_info" , "model_type" ])
352+ tmp_xml_path = Path (Path (xml_file ).parent ) / "tmp.xml"
353+ serialize (model , str (tmp_xml_path ))
354+ tmp_xml_path .rename (xml_file )
355+ Path (str (tmp_xml_path .parent / tmp_xml_path .stem ) + ".bin" ).unlink ()
356+
357+ def _serialize_list (self , arr : Union [Tuple , List ]) -> str :
358+ """Converts a list to space separated string."""
359+ return " " .join (map (str , arr ))
360+
316361 def model_info (self ) -> Dict :
317362 """Return model info to save the model weights.
318363
@@ -348,6 +393,12 @@ def save_model(self, output_model: ModelEntity) -> None:
348393 output_model .optimization_methods = self .optimization_methods
349394
350395 def _set_metadata (self , output_model : ModelEntity ):
396+ """Sets metadata in output_model."""
397+ metadata = self ._get_metadata_dict ()
398+ output_model .set_data ("metadata" , json .dumps (metadata ).encode ())
399+
400+ def _get_metadata_dict (self ) -> Dict [str , Any ]:
401+ """Returns metadata dict."""
351402 image_threshold = (
352403 self .model .image_threshold .value .cpu ().numpy ().tolist () if hasattr (self .model , "image_threshold" ) else 0.5
353404 )
@@ -384,7 +435,7 @@ def _set_metadata(self, output_model: ModelEntity):
384435 metadata ["max" ] = max
385436 # Set the task type for inferencer
386437 metadata ["task" ] = str (self .task_type ).lower ().split ("_" )[- 1 ]
387- output_model . set_data ( "metadata" , json . dumps ( metadata ). encode ())
438+ return metadata
388439
389440 @staticmethod
390441 def _is_docker () -> bool :
0 commit comments