|
11 | 11 | import numpy as np |
12 | 12 |
|
13 | 13 | from .inference_adapter import InferenceAdapter, Metadata |
14 | | -from .utils import Layout |
| 14 | +from .utils import Layout, get_rt_info_from_dict |
15 | 15 |
|
16 | 16 |
|
17 | 17 | class OVMSAdapter(InferenceAdapter): |
18 | | - """Class that allows working with models served by the OpenVINO Model Server""" |
| 18 | + """Inference adapter that allows working with models served by the OpenVINO Model Server""" |
19 | 19 |
|
20 | 20 | def __init__(self, target_model: str): |
21 | | - """Expected format: <address>:<port>/models/<model_name>[:<model_version>]""" |
22 | | - import ovmsclient |
| 21 | + """ |
| 22 | + Initializes the OVMS adapter. |
| 23 | +
|
| 24 | + Args: |
| 25 | + target_model (str): Model URL. Expected format: <address>:<port>/v2/models/<model_name>[:<model_version>] |
| 26 | + """ |
| 27 | + import tritonclient.http as httpclient |
23 | 28 |
|
24 | 29 | service_url, self.model_name, self.model_version = _parse_model_arg( |
25 | 30 | target_model, |
26 | 31 | ) |
27 | | - self.client = ovmsclient.make_grpc_client(url=service_url) |
28 | | - _verify_model_available(self.client, self.model_name, self.model_version) |
| 32 | + self.client = httpclient.InferenceServerClient(service_url) |
| 33 | + if not self.client.is_model_ready(self.model_name, self.model_version): |
| 34 | + msg = f"Requested model: {self.model_name}, version: {self.model_version} is not accessible" |
| 35 | + raise RuntimeError(msg) |
29 | 36 |
|
30 | 37 | self.metadata = self.client.get_model_metadata( |
31 | 38 | model_name=self.model_name, |
32 | 39 | model_version=self.model_version, |
33 | 40 | ) |
| 41 | + self.inputs = self.get_input_layers() |
| 42 | + |
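For illustration, the new constructor can be exercised like this (the address, port, and model name below are assumptions for the example, not values from this change):

    # Hypothetical OVMS instance serving a model called "resnet" on localhost:9000
    adapter = OVMSAdapter("localhost:9000/v2/models/resnet")       # latest version
    adapter_v2 = OVMSAdapter("localhost:9000/v2/models/resnet:2")  # pinned version 2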
| 43 | + def get_input_layers(self) -> dict[str, Metadata]: |
| 44 | + """ |
| 45 | + Retrieves information about the remote model's inputs. |
34 | 46 |
|
35 | | - def get_input_layers(self): |
| 47 | + Returns: |
| 48 | + dict[str, Metadata]: metadata for each input. |
| 49 | + """ |
36 | 50 | return { |
37 | | - name: Metadata( |
38 | | - {name}, |
| 51 | + meta["name"]: Metadata( |
| 52 | + {meta["name"]}, |
39 | 53 | meta["shape"], |
40 | 54 | Layout.from_shape(meta["shape"]), |
41 | | - _tf2ov_precision.get(meta["dtype"], meta["dtype"]), |
| 55 | + meta["datatype"], |
42 | 56 | ) |
43 | | - for name, meta in self.metadata["inputs"].items() |
| 57 | + for meta in self.metadata["inputs"] |
44 | 58 | } |
45 | 59 |
|
46 | | - def get_output_layers(self): |
| 60 | + def get_output_layers(self) -> dict[str, Metadata]: |
| 61 | + """ |
| 62 | + Retrieves information about the remote model's outputs. |
| 63 | +
|
| 64 | + Returns: |
| 65 | + dict[str, Metadata]: metadata for each output. |
| 66 | + """ |
47 | 67 | return { |
48 | | - name: Metadata( |
49 | | - {name}, |
| 68 | + meta["name"]: Metadata( |
| 69 | + {meta["name"]}, |
50 | 70 | shape=meta["shape"], |
51 | | - precision=_tf2ov_precision.get(meta["dtype"], meta["dtype"]), |
| 71 | + precision=meta["datatype"], |
52 | 72 | ) |
53 | | - for name, meta in self.metadata["outputs"].items() |
| 73 | + for meta in self.metadata["outputs"] |
54 | 74 | } |
55 | 75 |
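Both helpers can be used to inspect the remote model before running inference; a minimal sketch, assuming the hypothetical adapter above:

    for name, meta in adapter.get_input_layers().items():
        # meta.shape and meta.precision come straight from the KServe metadata
        print(name, meta.shape, meta.precision)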
|
56 | | - def infer_sync(self, dict_data): |
57 | | - inputs = _prepare_inputs(dict_data, self.metadata["inputs"]) |
58 | | - raw_result = self.client.predict( |
59 | | - inputs, |
| 76 | + def infer_sync(self, dict_data: dict) -> dict: |
| 77 | + """ |
| 78 | + Performs synchronous model inference. This is a blocking call. |
| 79 | +
|
| 80 | + Args: |
| 81 | + dict_data (dict): data for each input layer. |
| 82 | +
|
| 83 | + Returns: |
| 84 | + dict: model raw outputs. |
| 85 | + """ |
| 86 | + inputs = _prepare_inputs(dict_data, self.inputs) |
| 87 | + raw_result = self.client.infer( |
60 | 88 | model_name=self.model_name, |
61 | 89 | model_version=self.model_version, |
| 90 | + inputs=inputs, |
62 | 91 | ) |
63 | | - # For models with single output ovmsclient returns ndarray with results, |
64 | | - # so the dict must be created to correctly implement interface. |
65 | | - if isinstance(raw_result, np.ndarray): |
66 | | - output_name = next(iter(self.metadata["outputs"].keys())) |
67 | | - return {output_name: raw_result} |
68 | | - return raw_result |
69 | | - |
70 | | - def infer_async(self, dict_data, callback_data): |
71 | | - inputs = _prepare_inputs(dict_data, self.metadata["inputs"]) |
72 | | - raw_result = self.client.predict( |
73 | | - inputs, |
| 92 | + |
| 93 | + inference_results = {} |
| 94 | + for output in self.metadata["outputs"]: |
| 95 | + inference_results[output["name"]] = raw_result.as_numpy(output["name"]) |
| 96 | + |
| 97 | + return inference_results |
| 98 | + |
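A minimal synchronous call could then look as follows (the input name "data" and its shape are assumptions for the example):

    import numpy as np

    image = np.zeros((1, 3, 224, 224), dtype=np.float32)  # dummy input tensor
    outputs = adapter.infer_sync({"data": image})
    for name, tensor in outputs.items():
        print(name, tensor.shape)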
| 99 | + def infer_async(self, dict_data: dict, callback_data: Any): |
| 100 | + """A stub method imitating async inference with a blocking call.""" |
| 101 | + inputs = _prepare_inputs(dict_data, self.inputs) |
| 102 | + raw_result = self.client.infer( |
74 | 103 | model_name=self.model_name, |
75 | 104 | model_version=self.model_version, |
| 105 | + inputs=inputs, |
76 | 106 | ) |
77 | | - # For models with single output ovmsclient returns ndarray with results, |
78 | | - # so the dict must be created to correctly implement interface. |
79 | | - if isinstance(raw_result, np.ndarray): |
80 | | - output_name = next(iter(self.metadata["outputs"].keys())) |
81 | | - raw_result = {output_name: raw_result} |
82 | | - self.callback_fn(raw_result, (lambda x: x, callback_data)) |
| 107 | + inference_results = {} |
| 108 | + for output in self.metadata["outputs"]: |
| 109 | + inference_results[output["name"]] = raw_result.as_numpy(output["name"]) |
| 110 | + |
| 111 | + self.callback_fn(inference_results, (lambda x: x, callback_data)) |
83 | 112 |
|
84 | 113 | def set_callback(self, callback_fn: Callable): |
85 | 114 | self.callback_fn = callback_fn |
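Because infer_async is a blocking stub, the registered callback fires before the call returns. A sketch of the expected wiring, with the callback signature inferred from the code above (it receives the results dict plus a (function, callback_data) tuple):

    def on_result(result, userdata):
        print("outputs:", list(result))  # result maps output names to np.ndarray

    adapter.set_callback(on_result)
    adapter.infer_async({"data": image}, callback_data=42)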
@@ -118,97 +147,84 @@ def embed_preprocessing( |
118 | 147 | ): |
119 | 148 | pass |
120 | 149 |
|
121 | | - def reshape_model(self, new_shape): |
122 | | - raise NotImplementedError |
123 | | - |
124 | | - def get_rt_info(self, path): |
125 | | - msg = "OVMSAdapter does not support RT info getting" |
| 150 | + def reshape_model(self, new_shape: dict): |
| 151 | + """OVMS adapter can not modify the remote model. This method raises an exception.""" |
| 152 | + msg = "OVMSAdapter does not support model reshaping" |
126 | 153 | raise NotImplementedError(msg) |
127 | 154 |
|
| 155 | + def get_rt_info(self, path: list[str]) -> Any: |
| 156 | + """Returns an attribute stored in model info.""" |
| 157 | + return get_rt_info_from_dict(self.metadata["rt_info"], path) |
| 158 | + |
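Assuming the serving metadata exposes an "rt_info" section, an attribute can be fetched by its path; the path below is purely illustrative:

    model_type = adapter.get_rt_info(["model_info", "model_type"])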
128 | 159 | def update_model_info(self, model_info: dict[str, Any]): |
| 160 | + """OVMS adapter can not update the source model info. This method raises an exception.""" |
129 | 161 | msg = "OVMSAdapter does not support updating model info" |
130 | 162 | raise NotImplementedError(msg) |
131 | 163 |
|
132 | 164 | def save_model(self, path: str, weights_path: str | None = None, version: str | None = None): |
| 165 | + """OVMS adapter can not retrieve the source model. This method raises an exception.""" |
133 | 166 | msg = "OVMSAdapter does not support saving a model" |
134 | 167 | raise NotImplementedError(msg) |
135 | 168 |
|
136 | 169 |
|
137 | | -_tf2ov_precision = { |
138 | | - "DT_INT64": "I64", |
139 | | - "DT_UINT64": "U64", |
140 | | - "DT_FLOAT": "FP32", |
141 | | - "DT_UINT32": "U32", |
142 | | - "DT_INT32": "I32", |
143 | | - "DT_HALF": "FP16", |
144 | | - "DT_INT16": "I16", |
145 | | - "DT_INT8": "I8", |
146 | | - "DT_UINT8": "U8", |
147 | | -} |
148 | | - |
149 | | - |
150 | | -_tf2np_precision = { |
151 | | - "DT_INT64": np.int64, |
152 | | - "DT_UINT64": np.uint64, |
153 | | - "DT_FLOAT": np.float32, |
154 | | - "DT_UINT32": np.uint32, |
155 | | - "DT_INT32": np.int32, |
156 | | - "DT_HALF": np.float16, |
157 | | - "DT_INT16": np.int16, |
158 | | - "DT_INT8": np.int8, |
159 | | - "DT_UINT8": np.uint8, |
| 170 | +_triton2np_precision = { |
| 171 | + "INT64": np.int64, |
| 172 | + "UINT64": np.uint64, |
| 173 | + "FLOAT": np.float32, |
| 174 | + "UINT32": np.uint32, |
| 175 | + "INT32": np.int32, |
| 176 | + "HALF": np.float16, |
| 177 | + "INT16": np.int16, |
| 178 | + "INT8": np.int8, |
| 179 | + "UINT8": np.uint8, |
| 180 | + "FP32": np.float32, |
160 | 181 | } |
161 | 182 |
|
162 | 183 |
|
163 | 184 | def _parse_model_arg(target_model: str): |
| 185 | + """Parses OVMS model URL.""" |
164 | 186 | if not isinstance(target_model, str): |
165 | 187 | msg = "target_model must be str" |
166 | 188 | raise TypeError(msg) |
167 | | -    # Expected format: <address>:<port>/models/<model_name>[:<model_version>] |
| 189 | +    # Expected format: <address>:<port>/v2/models/<model_name>[:<model_version>] |
168 | 190 | if not re.fullmatch( |
169 | | - r"(\w+\.*\-*)*\w+:\d+\/models\/[a-zA-Z0-9._-]+(\:\d+)*", |
| 191 | + r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*", |
170 | 192 | target_model, |
171 | 193 | ): |
172 | 194 | msg = "invalid --model option format" |
173 | 195 | raise ValueError(msg) |
174 | | - service_url, _, model = target_model.split("/") |
| 196 | + service_url, _, _, model = target_model.split("/") |
175 | 197 | model_spec = model.split(":") |
176 | 198 | if len(model_spec) == 1: |
177 | 199 | # model version not specified - use latest |
178 | | - return service_url, model_spec[0], 0 |
| 200 | + return service_url, model_spec[0], "" |
179 | 201 | if len(model_spec) == 2: |
180 | | - return service_url, model_spec[0], int(model_spec[1]) |
181 | | - msg = "invalid target_model format" |
| 202 | + return service_url, model_spec[0], model_spec[1] |
| 203 | + msg = "Invalid target_model format" |
182 | 204 | raise ValueError(msg) |
183 | 205 |
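For example, the parser splits a versioned URL into its three components:

    _parse_model_arg("localhost:9000/v2/models/resnet:2")
    # -> ("localhost:9000", "resnet", "2")
    _parse_model_arg("localhost:9000/v2/models/resnet")
    # -> ("localhost:9000", "resnet", "")   # empty version means "latest"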
|
184 | 206 |
|
185 | | -def _verify_model_available(client, model_name, model_version): |
186 | | - import ovmsclient |
| 207 | +def _prepare_inputs(dict_data: dict, inputs_meta: dict[str, Metadata]): |
| 208 | + """Converts raw model inputs into OVMS-specific representation.""" |
| 209 | + import tritonclient.http as httpclient |
187 | 210 |
|
188 | | - version = "latest" if model_version == 0 else model_version |
189 | | - try: |
190 | | - model_status = client.get_model_status(model_name, model_version) |
191 | | - except ovmsclient.ModelNotFoundError as e: |
192 | | - msg = f"Requested model: {model_name}, version: {version} has not been found" |
193 | | - raise RuntimeError(msg) from e |
194 | | - target_version = max(model_status.keys()) |
195 | | - version_status = model_status[target_version] |
196 | | - if version_status["state"] != "AVAILABLE" or version_status["error_code"] != 0: |
197 | | - msg = f"Requested model: {model_name}, version: {version} is not in available state" |
198 | | - raise RuntimeError(msg) |
199 | | - |
200 | | - |
201 | | -def _prepare_inputs(dict_data, inputs_meta): |
202 | | - inputs = {} |
| 211 | + inputs = [] |
203 | 212 | for input_name, input_data in dict_data.items(): |
204 | 213 | if input_name not in inputs_meta: |
205 | 214 | msg = "Input data does not match model inputs" |
206 | 215 | raise ValueError(msg) |
207 | 216 | input_info = inputs_meta[input_name] |
208 | | - model_precision = _tf2np_precision[input_info["dtype"]] |
| 217 | + model_precision = _triton2np_precision[input_info.precision] |
209 | 218 | if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision: |
210 | 219 | input_data = input_data.astype(model_precision) |
211 | 220 | elif isinstance(input_data, list): |
212 | 221 | input_data = np.array(input_data, dtype=model_precision) |
213 | | - inputs[input_name] = input_data |
| 222 | + |
| 223 | + infer_input = httpclient.InferInput( |
| 224 | + input_name, |
| 225 | + input_data.shape, |
| 226 | + input_info.precision, |
| 227 | + ) |
| 228 | + infer_input.set_data_from_numpy(input_data) |
| 229 | + inputs.append(infer_input) |
214 | 230 | return inputs |