1010from contextlib import contextmanager
1111from typing import TYPE_CHECKING , Any , NoReturn , Type
1212
13+ from model_api .adapters .inference_adapter import InferenceAdapter
1314from model_api .adapters .onnx_adapter import ONNXRuntimeAdapter
1415from model_api .adapters .openvino_adapter import (
1516 OpenvinoAdapter ,
2324
2425 from numpy import ndarray
2526
26- from model_api .adapters .inference_adapter import InferenceAdapter
27-
2827
2928class WrapperError (Exception ):
3029 """The class for errors occurred in Model API wrappers"""
@@ -100,11 +99,7 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
10099 self .callback_fn = lambda _ : None
101100
102101 def get_model (self ):
103- model = self .inference_adapter .get_model ()
104- model .set_rt_info (self .__model__ , ["model_info" , "model_type" ])
105- for name in self .parameters ():
106- model .set_rt_info (getattr (self , name ), ["model_info" , name ])
107- return model
102+ return self .inference_adapter .get_model ()
108103
109104 @classmethod
110105 def get_model_class (cls , name : str ) -> Type :
@@ -122,7 +117,7 @@ def get_model_class(cls, name: str) -> Type:
122117 @classmethod
123118 def create_model (
124119 cls ,
125- model : str ,
120+ model : str | InferenceAdapter ,
126121 model_type : Any | None = None ,
127122 configuration : dict [str , Any ] = {},
128123 preload : bool = True ,
@@ -140,7 +135,7 @@ def create_model(
140135 """Create an instance of the Model API model
141136
142137 Args:
143- model (str): model name from OpenVINO Model Zoo, path to model, OVMS URL
138+ model (str| InferenceAdapter ): model name from OpenVINO Model Zoo, path to model, OVMS URL, or an adapter
144139 configuration (:obj:`dict`, optional): dictionary of model config with model properties, for example
145140 confidence_threshold, labels
146141 model_type (:obj:`str`, optional): name of model wrapper to create (e.g. "ssd")
@@ -162,7 +157,9 @@ def create_model(
162157 Model object
163158 """
164159 inference_adapter : InferenceAdapter
165- if isinstance (model , str ) and re .compile (
160+ if isinstance (model , InferenceAdapter ):
161+ inference_adapter = model
162+ elif isinstance (model , str ) and re .compile (
166163 r"(\w+\.*\-*)*\w+:\d+\/models\/[a-zA-Z0-9._-]+(\:\d+)*" ,
167164 ).fullmatch (model ):
168165 inference_adapter = OVMSAdapter (model )
@@ -487,7 +484,12 @@ def log_layers_info(self):
487484 f"precision: { metadata .precision } , layout: { metadata .layout } " ,
488485 )
489486
490- def save (self , xml_path , bin_path = "" , version = "UNSPECIFIED" ):
491- import openvino
487+ def save (self , path : str , weights_path : str = "" , version : str = "UNSPECIFIED" ):
488+ model_info = {
489+ "model_type" : self .__model__ ,
490+ }
491+ for name in self .parameters ():
492+ model_info [name ] = getattr (self , name )
492493
493- openvino .serialize (self .get_model (), xml_path , bin_path , version )
494+ self .inference_adapter .update_model_info (model_info )
495+ self .inference_adapter .save_model (path , weights_path , version )
0 commit comments