import numpy as np

from .inference_adapter import InferenceAdapter, Metadata
from .utils import Layout, get_rt_info_from_dict
2525class OVMSAdapter (InferenceAdapter ):
@@ -29,62 +29,65 @@ class OVMSAdapter(InferenceAdapter):
2929
3030 def __init__ (self , target_model : str ):
3131 """Expected format: <address>:<port>/models/<model_name>[:<model_version>]"""
32- import ovmsclient
32+ import tritonclient . http as httpclient
3333
3434 service_url , self .model_name , self .model_version = _parse_model_arg (
3535 target_model
3636 )
37- self .client = ovmsclient .make_grpc_client (url = service_url )
38- _verify_model_available (self .client , self .model_name , self .model_version )
37+ self .client = httpclient .InferenceServerClient (service_url )
38+ if not self .client .is_model_ready (self .model_name , self .model_version ):
39+ raise RuntimeError (
40+ f"Requested model: { self .model_name } , version: { self .model_version } is not accessible"
41+ )
3942
4043 self .metadata = self .client .get_model_metadata (
4144 model_name = self .model_name , model_version = self .model_version
4245 )
46+ self .inputs = self .get_input_layers ()
4347
4448 def get_input_layers (self ):
4549 return {
46- name : Metadata (
47- {name },
50+ meta [ " name" ] : Metadata (
51+ {meta [ " name" ] },
4852 meta ["shape" ],
4953 Layout .from_shape (meta ["shape" ]),
50- _tf2ov_precision . get ( meta ["dtype" ], meta [ "dtype" ]) ,
54+ meta ["datatype" ] ,
5155 )
52- for name , meta in self .metadata ["inputs" ]. items ()
56+ for meta in self .metadata ["inputs" ]
5357 }
5458
5559 def get_output_layers (self ):
5660 return {
57- name : Metadata (
58- {name },
61+ meta [ " name" ] : Metadata (
62+ {meta [ " name" ] },
5963 shape = meta ["shape" ],
60- precision = _tf2ov_precision . get ( meta ["dtype" ], meta [ "dtype" ]) ,
64+ precision = meta ["datatype" ] ,
6165 )
62- for name , meta in self .metadata ["outputs" ]. items ()
66+ for meta in self .metadata ["outputs" ]
6367 }
6468
6569 def infer_sync (self , dict_data ):
66- inputs = _prepare_inputs (dict_data , self .metadata [ " inputs" ] )
67- raw_result = self .client .predict (
68- inputs , model_name = self .model_name , model_version = self .model_version
70+ inputs = _prepare_inputs (dict_data , self .inputs )
71+ raw_result = self .client .infer (
72+ model_name = self .model_name , model_version = self .model_version , inputs = inputs
6973 )
70- # For models with single output ovmsclient returns ndarray with results,
71- # so the dict must be created to correctly implement interface.
72- if isinstance ( raw_result , np . ndarray ) :
73- output_name = next ( iter (( self . metadata [ "outputs" ]. keys ())) )
74- return { output_name : raw_result }
75- return raw_result
74+
75+ inference_results = {}
76+ for output in self . metadata [ "outputs" ] :
77+ inference_results [ output [ "name" ]] = raw_result . as_numpy ( output [ "name" ] )
78+
79+ return inference_results
7680
7781 def infer_async (self , dict_data , callback_data ):
78- inputs = _prepare_inputs (dict_data , self .metadata [ " inputs" ] )
79- raw_result = self .client .predict (
80- inputs , model_name = self .model_name , model_version = self .model_version
82+ inputs = _prepare_inputs (dict_data , self .inputs )
83+ raw_result = self .client .infer (
84+ model_name = self .model_name , model_version = self .model_version , inputs = inputs
8185 )
82- # For models with single output ovmsclient returns ndarray with results,
83- # so the dict must be created to correctly implement interface.
84- if isinstance (raw_result , np .ndarray ):
85- output_name = list (self .metadata ["outputs" ].keys ())[0 ]
86- raw_result = {output_name : raw_result }
87- self .callback_fn (raw_result , (lambda x : x , callback_data ))
86+ inference_results = {}
87+ for output in self .metadata ["outputs" ]:
88+ inference_results [output ["name" ]] = raw_result .as_numpy (output ["name" ])
89+
90+ self .callback_fn (inference_results , (lambda x : x , callback_data ))
8891
8992 def set_callback (self , callback_fn ):
9093 self .callback_fn = callback_fn
@@ -120,32 +123,19 @@ def reshape_model(self, new_shape):
120123 raise NotImplementedError
121124
122125 def get_rt_info (self , path ):
123- raise NotImplementedError ("OVMSAdapter does not support RT info getting" )
124-
125-
126- _tf2ov_precision = {
127- "DT_INT64" : "I64" ,
128- "DT_UINT64" : "U64" ,
129- "DT_FLOAT" : "FP32" ,
130- "DT_UINT32" : "U32" ,
131- "DT_INT32" : "I32" ,
132- "DT_HALF" : "FP16" ,
133- "DT_INT16" : "I16" ,
134- "DT_INT8" : "I8" ,
135- "DT_UINT8" : "U8" ,
136- }
137-
138-
139- _tf2np_precision = {
140- "DT_INT64" : np .int64 ,
141- "DT_UINT64" : np .uint64 ,
142- "DT_FLOAT" : np .float32 ,
143- "DT_UINT32" : np .uint32 ,
144- "DT_INT32" : np .int32 ,
145- "DT_HALF" : np .float16 ,
146- "DT_INT16" : np .int16 ,
147- "DT_INT8" : np .int8 ,
148- "DT_UINT8" : np .uint8 ,
126+ return get_rt_info_from_dict (self .metadata ["rt_info" ], path )
127+
128+
129+ _triton2np_precision = {
130+ "INT64" : np .int64 ,
131+ "UINT64" : np .uint64 ,
132+ "FLOAT" : np .float32 ,
133+ "UINT32" : np .uint32 ,
134+ "INT32" : np .int32 ,
135+ "HALF" : np .float16 ,
136+ "INT16" : np .int16 ,
137+ "INT8" : np .int8 ,
138+ "UINT8" : np .uint8 ,
149139}
150140
151141
@@ -161,40 +151,29 @@ def _parse_model_arg(target_model: str):
161151 model_spec = model .split (":" )
162152 if len (model_spec ) == 1 :
163153 # model version not specified - use latest
164- return service_url , model_spec [0 ], 0
154+ return service_url , model_spec [0 ], ""
165155 if len (model_spec ) == 2 :
166- return service_url , model_spec [0 ], int ( model_spec [1 ])
156+ return service_url , model_spec [0 ], model_spec [1 ]
167157 raise ValueError ("invalid target_model format" )
168158
169159
170- def _verify_model_available (client , model_name , model_version ):
171- import ovmsclient
172-
173- version = "latest" if model_version == 0 else model_version
174- try :
175- model_status = client .get_model_status (model_name , model_version )
176- except ovmsclient .ModelNotFoundError as e :
177- raise RuntimeError (
178- f"Requested model: { model_name } , version: { version } has not been found"
179- ) from e
180- target_version = max (model_status .keys ())
181- version_status = model_status [target_version ]
182- if version_status ["state" ] != "AVAILABLE" or version_status ["error_code" ] != 0 :
183- raise RuntimeError (
184- f"Requested model: { model_name } , version: { version } is not in available state"
185- )
186-
187-
188160def _prepare_inputs (dict_data , inputs_meta ):
189- inputs = {}
161+ import tritonclient .http as httpclient
162+
163+ inputs = []
190164 for input_name , input_data in dict_data .items ():
191165 if input_name not in inputs_meta .keys ():
192166 raise ValueError ("Input data does not match model inputs" )
193167 input_info = inputs_meta [input_name ]
194- model_precision = _tf2np_precision [input_info [ "dtype" ] ]
168+ model_precision = _triton2np_precision [input_info . precision ]
195169 if isinstance (input_data , np .ndarray ) and input_data .dtype != model_precision :
196170 input_data = input_data .astype (model_precision )
197171 elif isinstance (input_data , list ):
198172 input_data = np .array (input_data , dtype = model_precision )
199- inputs [input_name ] = input_data
173+
174+ infer_input = httpclient .InferInput (
175+ input_name , input_data .shape , input_info .precision
176+ )
177+ infer_input .set_data_from_numpy (input_data )
178+ inputs .append (infer_input )
200179 return inputs
0 commit comments