@@ -52,6 +52,88 @@ def parse_triton_config_pbtxt(pbtxt_path) -> ModelConfig:
         raise ValueError(f"Failed to parse config file {pbtxt_path}") from e


+class TritonRemoteModel:
+    """A remote model that is hosted on a Triton Inference Server.
+
+    Args:
+        model_name (str): The name of the model.
+        netloc (str): The network location of the Triton Inference Server.
+        model_config (ModelConfig): The model config.
+        headers (dict): The headers to send to the Triton Inference Server.
+    """
+
+    def __init__(self, model_name, netloc, model_config, headers=None, **kwargs):
+        self._headers = headers
+        self._request_compression_algorithm = None
+        self._response_compression_algorithm = None
+        self._model_name = model_name
+        self._model_version = None
+        self._model_config = model_config
+        self._request_compression_algorithm = None
+        self._response_compression_algorithm = None
+        self._count = 0
+
+        try:
+            self._triton_client = httpclient.InferenceServerClient(url=netloc, verbose=kwargs.get("verbose", False))
+            logging.info(f"Created triton client: {self._triton_client}")
+        except Exception as e:
+            logging.error("channel creation failed: " + str(e))
+            raise
+
+    def __call__(self, data, **kwds):
+
+        self._count += 1
+        logging.info(f"{self.__class__.__name__}.__call__: {self._model_name} count: {self._count}")
+
+        inputs = []
+        outputs = []
+
+        # For now support only one input and one output
+        input_name = self._model_config.input[0].name
+        input_type = str.split(DataType.Name(self._model_config.input[0].data_type), "_")[1]  # remove the prefix
+        input_shape = list(self._model_config.input[0].dims)
+        data_shape = list(data.shape)
+        logging.info(f"Model config input data shape: {input_shape}")
+        logging.info(f"Actual input data shape: {data_shape}")
+
+        # The server side will handle the batching, and with dynamic batching
+        # the model config does not have the batch size in the input dims.
+        logging.info(f"Effective input_name: {input_name}, input_type: {input_type}, input_shape: {data_shape}")
+
+        inputs.append(httpclient.InferInput(input_name, data_shape, input_type))
+
+        # Move the tensor to CPU
+        input0_data_np = data.detach().cpu().numpy()
+        logging.debug(f"Input data shape: {input0_data_np.shape}")
+
+        # Initialize the data
+        inputs[0].set_data_from_numpy(input0_data_np, binary_data=False)
+
+        output_name = self._model_config.output[0].name
+        outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
+
+        query_params = {f"{self._model_name}_count": self._count}
+        results = self._triton_client.infer(
+            self._model_name,
+            inputs,
+            outputs=outputs,
+            query_params=query_params,
+            headers=self._headers,
+            request_compression_algorithm=self._request_compression_algorithm,
+            response_compression_algorithm=self._response_compression_algorithm,
+        )
+
+        logging.info(f"Got results: {results.get_response()}")
+        output0_data = results.as_numpy(output_name)
+        logging.debug(f"as_numpy output0_data.shape: {output0_data.shape}")
+        logging.debug(f"as_numpy output0_data.dtype: {output0_data.dtype}")
+
+        # Convert the numpy array to a torch tensor, as expected by the anticipated clients,
+        # e.g. MONAI sliding window inference.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return torch.as_tensor(output0_data).to(device)  # from_numpy is fine too.
+
+
 class TritonModel(Model):
     """Represents Triton models in the model repository.

@@ -124,9 +206,7 @@ def __init__(self, path: str, name: str = ""):
124206 f"Model name in config.pbtxt ({ self ._model_config .name } ) does not match the folder name ({ self ._name } )."
125207 )
126208
127- self ._netloc = None # network location of the Triton Inference Server
128- self ._predictor = None # triton remote client
129-
209+ self ._netloc : str = ""
130210 logging .info (f"Created Triton model: { self ._name } " )
131211
132212 def connect (self , netloc : str , ** kwargs ):
@@ -137,36 +217,51 @@ def connect(self, netloc: str, **kwargs):
         """

         if not netloc:
-            if not self._predictor:
-                raise ValueError("Network location is required to connect to the Triton Inference Server.")
-            else:
-                logging.warning("No network location provided, using the last connected network location.")
+            raise ValueError("Network location is required to connect to the Triton Inference Server.")

-        if self._predictor and not self._netloc.casefold() == netloc.casefold():
+        if self._netloc and not self._netloc.casefold() == netloc.casefold():
             logging.warning(f"Reconnecting to a different Triton Inference Server at {netloc} from {self._netloc}.")

+        self._predictor = TritonRemoteModel(self._name, netloc, self._model_config, **kwargs)
         self._netloc = netloc
-        self._predictor = TritonRemoteModel(self._name, self._netloc, self._model_config, **kwargs)
+
         return self._predictor

     @property
     def model_config(self):
-        if not self._model_config:  # not expected to happen with the current implementation.
-            self._model_config = parse_triton_config_pbtxt(self._model_path / "config.pbtxt")
         return self._model_config

     @property
-    def predictor(self):
-        """Get the model's predictor (triton remote client)
+    def net_loc(self):
+        """Get the network location of the Triton Inference Server, i.e. "<host>:<port>".

         Returns:
-            the model's predictor
+            str: The network location of the Triton Inference Server.
         """
-        if self._predictor is None:
-            raise ValueError("Model is not connected to the Triton Inference Server.")

+        return self._netloc
+
+    @net_loc.setter
+    def net_loc(self, value: str):
+        """Set the network location of the Triton Inference Server and reconnect."""
+        if not value:
+            raise ValueError("Network location cannot be empty.")
+        self._netloc = value
+        # Reconnect to the Triton Inference Server at the new network location.
+        self.connect(value)
+
+    @property
+    def predictor(self):
+        if not self._predictor:
+            raise ValueError("Model is not connected to the Triton Inference Server.")
         return self._predictor

+    @predictor.setter
+    def predictor(self, predictor: TritonRemoteModel):
+        if not isinstance(predictor, TritonRemoteModel):
+            raise ValueError("Predictor must be an instance of TritonRemoteModel.")
+        self._predictor = predictor
+
     @classmethod
     def accept(cls, path: str) -> tuple[bool, str]:
         model_folder: Path = Path(path)
@@ -195,85 +290,3 @@ def accept(cls, path: str) -> tuple[bool, str]:
             logging.info(f"Model {model_folder.name} only has config.pbtxt in client workspace.")

         return True, cls.model_type
-
-
-class TritonRemoteModel:
-    """A remote model that is hosted on a Triton Inference Server.
-
-    Args:
-        model_name (str): The name of the model.
-        netloc (str): The network location of the Triton Inference Server.
-        model_config (ModelConfig): The model config.
-        headers (dict): The headers to send to the Triton Inference Server.
-    """
-
-    def __init__(self, model_name, netloc, model_config, headers=None, **kwargs):
-        self._headers = headers
-        self._request_compression_algorithm = None
-        self._response_compression_algorithm = None
-        self._model_name = model_name
-        self._model_version = None
-        self._model_config = model_config
-        self._request_compression_algorithm = None
-        self._response_compression_algorithm = None
-        self._count = 0
-
-        try:
-            self._triton_client = httpclient.InferenceServerClient(url=netloc, verbose=kwargs.get("verbose", False))
-            print(f"Created triton client: {self._triton_client}")
-        except Exception as e:
-            logging.error("channel creation failed: " + str(e))
-            raise
-
-    def __call__(self, data, **kwds):
-
-        self._count += 1
-        logging.info(f"{self.__class__.__name__}.__call__: {self._model_name} count: {self._count}")
-
-        inputs = []
-        outputs = []
-
-        # For now support only one input and one output
-        input_name = self._model_config.input[0].name
-        input_type = str.split(DataType.Name(self._model_config.input[0].data_type), "_")[1]  # remove the prefix
-        input_shape = list(self._model_config.input[0].dims)
-        data_shape = list(data.shape)
-        logging.info(f"Model config input data shape: {input_shape}")
-        logging.info(f"Actual input data shape: {data_shape}")
-
-        # The server side will handle the batching, and with dynamic batching
-        # the model config does not have the batch size in the input dims.
-        logging.info(f"Effective input_name: {input_name}, input_type: {input_type}, input_shape: {data_shape}")
-
-        inputs.append(httpclient.InferInput(input_name, data_shape, input_type))
-
-        # Move the tensor to CPU
-        input0_data_np = data.detach().cpu().numpy()
-        logging.debug(f"Input data shape: {input0_data_np.shape}")
-
-        # Initialize the data
-        inputs[0].set_data_from_numpy(input0_data_np, binary_data=False)
-
-        output_name = self._model_config.output[0].name
-        outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
-
-        query_params = {f"{self._model_name}_count": self._count}
-        results = self._triton_client.infer(
-            self._model_name,
-            inputs,
-            outputs=outputs,
-            query_params=query_params,
-            headers=self._headers,
-            request_compression_algorithm=self._request_compression_algorithm,
-            response_compression_algorithm=self._response_compression_algorithm,
-        )
-
-        logging.info(f"Got results: {results.get_response()}")
-        output0_data = results.as_numpy(output_name)
-        logging.debug(f"as_numpy output0_data.shape: {output0_data.shape}")
-        logging.debug(f"as_numpy output0_data.dtype: {output0_data.dtype}")
-
-        # Convert the numpy array to a torch tensor, as expected by the anticipated clients,
-        # e.g. MONAI sliding window inference.
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        return torch.as_tensor(output0_data).to(device)  # from_numpy is fine too.
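
For orientation, here is a minimal usage sketch of the classes touched by this change, as seen from client code. It is not part of the diff, and the import path, server address, model folder, and tensor shape below are placeholder assumptions.

```python
# Hypothetical usage sketch; import path, address, folder, and shape are placeholders.
import logging

import torch

from monai.deploy.core.models import TritonModel  # assumed import path for this module

logging.basicConfig(level=logging.INFO)

model_path = "models/example_model"  # folder expected to contain config.pbtxt

accepted, model_type = TritonModel.accept(model_path)
if accepted:
    model = TritonModel(model_path, name="example_model")

    # connect() builds the TritonRemoteModel HTTP client and returns it; it is
    # also available later through the predictor property, and assigning
    # net_loc triggers a reconnect.
    predictor = model.connect("localhost:8000", verbose=False)

    # TritonRemoteModel.__call__ takes a torch tensor, runs inference on the
    # server, and returns a torch tensor (on GPU if available, else CPU).
    dummy_input = torch.rand(1, 1, 96, 96, 96)  # placeholder shape
    output = predictor(dummy_input)
    print(output.shape, output.device)
```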