33# SPDX-License-Identifier: Apache-2.0
44#
55
6- import abc
6+ from abc import ABC , abstractmethod
77from dataclasses import dataclass , field
8+ from typing import Any , Dict , List , Set , Tuple
89
910
1011@dataclass
@@ -17,30 +18,37 @@ class Metadata:
1718 meta : dict = field (default_factory = dict )
1819
1920
20- class InferenceAdapter (abc .ABC ):
21- """An abstract Model Adapter with the following interface:
22-
23- - Reading the model from disk or other place
24- - Loading the model to the device
25- - Accessing the information about inputs/outputs
26- - The model reshaping
27- - Synchronous model inference
28- - Asynchronous model inference
21+ class InferenceAdapter (ABC ):
22+ """
23+ An abstract Model Adapter with the following interface:
24+
25+ - Reading the model from disk or other place
26+ - Loading the model to the device
27+ - Accessing the information about inputs/outputs
28+ - The model reshaping
29+ - Synchronous model inference
30+ - Asynchronous model inference
2931 """
3032
3133 precisions = ("FP32" , "I32" , "FP16" , "I16" , "I8" , "U8" )
3234
33- @abc .abstractmethod
34- def __init__ (self ):
35- """An abstract Model Adapter constructor.
35+ @abstractmethod
36+ def __init__ (self ) -> None :
37+ """
38+ An abstract Model Adapter constructor.
3639 Reads the model from disk or other place.
3740 """
41+ self .model : Any
3842
39- @abc . abstractmethod
43+ @abstractmethod
4044 def load_model (self ):
4145 """Loads the model on the device."""
4246
43- @abc .abstractmethod
47+ @abstractmethod
48+ def get_model (self ):
49+ """Get the model."""
50+
51+ @abstractmethod
4452 def get_input_layers (self ):
4553 """Gets the names of model inputs and for each one creates the Metadata structure,
4654 which contains the information about the input shape, layout, precision
@@ -50,7 +58,7 @@ def get_input_layers(self):
5058 - the dict containing Metadata for all inputs
5159 """
5260
53- @abc . abstractmethod
61+ @abstractmethod
5462 def get_output_layers (self ):
5563 """Gets the names of model outputs and for each one creates the Metadata structure,
5664 which contains the information about the output shape, layout, precision
@@ -60,7 +68,7 @@ def get_output_layers(self):
6068 - the dict containing Metadata for all outputs
6169 """
6270
63- @abc . abstractmethod
71+ @abstractmethod
6472 def reshape_model (self , new_shape ):
6573 """Reshapes the model inputs to fit the new input shape.
6674
@@ -74,7 +82,7 @@ def reshape_model(self, new_shape):
7482 }
7583 """
7684
77- @abc . abstractmethod
85+ @abstractmethod
7886 def infer_sync (self , dict_data ):
7987 """Performs the synchronous model inference. The infer is a blocking method.
8088
@@ -95,9 +103,10 @@ def infer_sync(self, dict_data):
95103 }
96104 """
97105
98- @abc .abstractmethod
99- def infer_async (self , dict_data , callback_fn , callback_data ):
100- """Performs the asynchronous model inference and sets
106+ @abstractmethod
107+ def infer_async (self , dict_data , callback_data ):
108+ """
109+ Performs the asynchronous model inference and sets
101110 the callback for inference completion. Also, it should
102111 define get_raw_result() function, which handles the result
103112 of inference from the model.
@@ -109,11 +118,10 @@ def infer_async(self, dict_data, callback_fn, callback_data):
109118 'input_layer_name_2': data_2,
110119 ...
111120 }
112- - callback_fn: the callback function, which is defined outside the adapter
113121 - callback_data: the data for callback, that will be taken after the model inference is ended
114122 """
115123
116- @abc . abstractmethod
124+ @abstractmethod
117125 def is_ready (self ):
118126 """In case of asynchronous execution checks if one can submit input data
119127 to the model for inference, or all infer requests are busy.
@@ -123,23 +131,23 @@ def is_ready(self):
123131 submitted to the model for inference or not
124132 """
125133
126- @abc . abstractmethod
134+ @abstractmethod
127135 def await_all (self ):
128136 """In case of asynchronous execution waits the completion of all
129137 busy infer requests.
130138 """
131139
132- @abc . abstractmethod
140+ @abstractmethod
133141 def await_any (self ):
134142 """In case of asynchronous execution waits the completion of any
135143 busy infer request until it becomes available for the data submission.
136144 """
137145
138- @abc . abstractmethod
146+ @abstractmethod
139147 def get_rt_info (self , path ):
140148 """Forwards to openvino.Model.get_rt_info(path)"""
141149
142- @abc . abstractmethod
150+ @abstractmethod
143151 def embed_preprocessing (
144152 self ,
145153 layout ,
0 commit comments