Skip to content

Commit 8417e81

Browse files
committed
Switch to tritonclient in OVMS adapter
1 parent 5a06e0c commit 8417e81

File tree

3 files changed

+60
-84
lines changed

3 files changed

+60
-84
lines changed

model_api/python/model_api/adapters/ovms_adapter.py

Lines changed: 58 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020

2121
from .inference_adapter import InferenceAdapter, Metadata
22-
from .utils import Layout
22+
from .utils import Layout, get_rt_info_from_dict
2323

2424

2525
class OVMSAdapter(InferenceAdapter):
@@ -29,62 +29,65 @@ class OVMSAdapter(InferenceAdapter):
2929

3030
def __init__(self, target_model: str):
3131
"""Expected format: <address>:<port>/models/<model_name>[:<model_version>]"""
32-
import ovmsclient
32+
import tritonclient.http as httpclient
3333

3434
service_url, self.model_name, self.model_version = _parse_model_arg(
3535
target_model
3636
)
37-
self.client = ovmsclient.make_grpc_client(url=service_url)
38-
_verify_model_available(self.client, self.model_name, self.model_version)
37+
self.client = httpclient.InferenceServerClient(service_url)
38+
if not self.client.is_model_ready(self.model_name, self.model_version):
39+
raise RuntimeError(
40+
f"Requested model: {self.model_name}, version: {self.model_version} is not accessible"
41+
)
3942

4043
self.metadata = self.client.get_model_metadata(
4144
model_name=self.model_name, model_version=self.model_version
4245
)
46+
self.inputs = self.get_input_layers()
4347

4448
def get_input_layers(self):
4549
return {
46-
name: Metadata(
47-
{name},
50+
meta["name"]: Metadata(
51+
{meta["name"]},
4852
meta["shape"],
4953
Layout.from_shape(meta["shape"]),
50-
_tf2ov_precision.get(meta["dtype"], meta["dtype"]),
54+
meta["datatype"],
5155
)
52-
for name, meta in self.metadata["inputs"].items()
56+
for meta in self.metadata["inputs"]
5357
}
5458

5559
def get_output_layers(self):
5660
return {
57-
name: Metadata(
58-
{name},
61+
meta["name"]: Metadata(
62+
{meta["name"]},
5963
shape=meta["shape"],
60-
precision=_tf2ov_precision.get(meta["dtype"], meta["dtype"]),
64+
precision=meta["datatype"],
6165
)
62-
for name, meta in self.metadata["outputs"].items()
66+
for meta in self.metadata["outputs"]
6367
}
6468

6569
def infer_sync(self, dict_data):
66-
inputs = _prepare_inputs(dict_data, self.metadata["inputs"])
67-
raw_result = self.client.predict(
68-
inputs, model_name=self.model_name, model_version=self.model_version
70+
inputs = _prepare_inputs(dict_data, self.inputs)
71+
raw_result = self.client.infer(
72+
model_name=self.model_name, model_version=self.model_version, inputs=inputs
6973
)
70-
# For models with single output ovmsclient returns ndarray with results,
71-
# so the dict must be created to correctly implement interface.
72-
if isinstance(raw_result, np.ndarray):
73-
output_name = next(iter((self.metadata["outputs"].keys())))
74-
return {output_name: raw_result}
75-
return raw_result
74+
75+
inference_results = {}
76+
for output in self.metadata["outputs"]:
77+
inference_results[output["name"]] = raw_result.as_numpy(output["name"])
78+
79+
return inference_results
7680

7781
def infer_async(self, dict_data, callback_data):
78-
inputs = _prepare_inputs(dict_data, self.metadata["inputs"])
79-
raw_result = self.client.predict(
80-
inputs, model_name=self.model_name, model_version=self.model_version
82+
inputs = _prepare_inputs(dict_data, self.inputs)
83+
raw_result = self.client.infer(
84+
model_name=self.model_name, model_version=self.model_version, inputs=inputs
8185
)
82-
# For models with single output ovmsclient returns ndarray with results,
83-
# so the dict must be created to correctly implement interface.
84-
if isinstance(raw_result, np.ndarray):
85-
output_name = list(self.metadata["outputs"].keys())[0]
86-
raw_result = {output_name: raw_result}
87-
self.callback_fn(raw_result, (lambda x: x, callback_data))
86+
inference_results = {}
87+
for output in self.metadata["outputs"]:
88+
inference_results[output["name"]] = raw_result.as_numpy(output["name"])
89+
90+
self.callback_fn(inference_results, (lambda x: x, callback_data))
8891

8992
def set_callback(self, callback_fn):
9093
self.callback_fn = callback_fn
@@ -120,32 +123,19 @@ def reshape_model(self, new_shape):
120123
raise NotImplementedError
121124

122125
def get_rt_info(self, path):
123-
raise NotImplementedError("OVMSAdapter does not support RT info getting")
124-
125-
126-
_tf2ov_precision = {
127-
"DT_INT64": "I64",
128-
"DT_UINT64": "U64",
129-
"DT_FLOAT": "FP32",
130-
"DT_UINT32": "U32",
131-
"DT_INT32": "I32",
132-
"DT_HALF": "FP16",
133-
"DT_INT16": "I16",
134-
"DT_INT8": "I8",
135-
"DT_UINT8": "U8",
136-
}
137-
138-
139-
_tf2np_precision = {
140-
"DT_INT64": np.int64,
141-
"DT_UINT64": np.uint64,
142-
"DT_FLOAT": np.float32,
143-
"DT_UINT32": np.uint32,
144-
"DT_INT32": np.int32,
145-
"DT_HALF": np.float16,
146-
"DT_INT16": np.int16,
147-
"DT_INT8": np.int8,
148-
"DT_UINT8": np.uint8,
126+
return get_rt_info_from_dict(self.metadata["rt_info"], path)
127+
128+
129+
_triton2np_precision = {
130+
"INT64": np.int64,
131+
"UINT64": np.uint64,
132+
"FLOAT": np.float32,
133+
"UINT32": np.uint32,
134+
"INT32": np.int32,
135+
"HALF": np.float16,
136+
"INT16": np.int16,
137+
"INT8": np.int8,
138+
"UINT8": np.uint8,
149139
}
150140

151141

@@ -161,40 +151,29 @@ def _parse_model_arg(target_model: str):
161151
model_spec = model.split(":")
162152
if len(model_spec) == 1:
163153
# model version not specified - use latest
164-
return service_url, model_spec[0], 0
154+
return service_url, model_spec[0], ""
165155
if len(model_spec) == 2:
166-
return service_url, model_spec[0], int(model_spec[1])
156+
return service_url, model_spec[0], model_spec[1]
167157
raise ValueError("invalid target_model format")
168158

169159

170-
def _verify_model_available(client, model_name, model_version):
171-
import ovmsclient
172-
173-
version = "latest" if model_version == 0 else model_version
174-
try:
175-
model_status = client.get_model_status(model_name, model_version)
176-
except ovmsclient.ModelNotFoundError as e:
177-
raise RuntimeError(
178-
f"Requested model: {model_name}, version: {version} has not been found"
179-
) from e
180-
target_version = max(model_status.keys())
181-
version_status = model_status[target_version]
182-
if version_status["state"] != "AVAILABLE" or version_status["error_code"] != 0:
183-
raise RuntimeError(
184-
f"Requested model: {model_name}, version: {version} is not in available state"
185-
)
186-
187-
188160
def _prepare_inputs(dict_data, inputs_meta):
189-
inputs = {}
161+
import tritonclient.http as httpclient
162+
163+
inputs = []
190164
for input_name, input_data in dict_data.items():
191165
if input_name not in inputs_meta.keys():
192166
raise ValueError("Input data does not match model inputs")
193167
input_info = inputs_meta[input_name]
194-
model_precision = _tf2np_precision[input_info["dtype"]]
168+
model_precision = _triton2np_precision[input_info.precision]
195169
if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision:
196170
input_data = input_data.astype(model_precision)
197171
elif isinstance(input_data, list):
198172
input_data = np.array(input_data, dtype=model_precision)
199-
inputs[input_name] = input_data
173+
174+
infer_input = httpclient.InferInput(
175+
input_name, input_data.shape, input_info.precision
176+
)
177+
infer_input.set_data_from_numpy(input_data)
178+
inputs.append(infer_input)
200179
return inputs

model_api/python/model_api/models/model.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -268,10 +268,7 @@ def _load_config(self, config):
268268
"Cannot get runtime attribute. Path to runtime attribute is incorrect."
269269
in str(error)
270270
)
271-
is_OVMSAdapter = (
272-
str(error) == "OVMSAdapter does not support RT info getting"
273-
)
274-
if not missing_rt_info and not is_OVMSAdapter:
271+
if not missing_rt_info:
275272
raise
276273

277274
for name, value in config.items():

model_api/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ dependencies = [
3333

3434
[project.optional-dependencies]
3535
ovms = [
36-
"ovmsclient",
36+
"tritonclient[http]",
3737
]
3838
tests = [
3939
"httpx",

0 commit comments

Comments
 (0)