
Commit a860ea9

Author: David Eigen
make modelinfo for inference thread-safe
1 parent 324491e

File tree: 1 file changed (+19, -16 lines)

clarifai/client/model.py

Lines changed: 19 additions & 16 deletions
@@ -424,14 +424,14 @@ def predict(self,
       raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
                      ) # TODO Use Chunker for inputs len > 128
 
-    self._override_model_version(inference_params, output_config)
+    model_info = self._get_model_info_for_inference(inference_params, output_config)
     request = service_pb2.PostModelOutputsRequest(
         user_app_id=self.user_app_id,
         model_id=self.id,
         version_id=self.model_version.id,
         inputs=inputs,
         runner_selector=runner_selector,
-        model=self.model_info)
+        model=model_info)
 
     start_time = time.time()
     backoff_iterator = BackoffIterator(10)
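
For context on the commit message: before this change, predict() mutated self.model_info.model_version.output_info in place on every call, so two threads predicting through the same Model object with different inference_params could overwrite each other's settings mid-request. Below is a minimal sketch of that hazard and of the copy-per-call fix, assuming only clarifai_grpc's resources_pb2; the helper names are illustrative, not part of the commit.

# Illustrative sketch, not part of the commit: shared mutation vs. per-call copy.
from clarifai_grpc.grpc.api import resources_pb2

shared_model_info = resources_pb2.Model(id="my-model")  # stands in for self.model_info

def unsafe_prepare(inference_params: dict) -> resources_pb2.Model:
  # Old pattern: every caller writes into the same message, so a concurrent
  # caller can observe or clobber these params before the request is sent.
  shared_model_info.model_version.output_info.params.update(inference_params)
  return shared_model_info

def safe_prepare(inference_params: dict) -> resources_pb2.Model:
  # New pattern: copy first, mutate only the copy, attach the copy to the request.
  model_info = resources_pb2.Model()
  model_info.CopyFrom(shared_model_info)
  model_info.model_version.output_info.params.update(inference_params)
  return model_info
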
@@ -704,14 +704,16 @@ def generate(self,
       raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
                      ) # TODO Use Chunker for inputs len > 128
 
-    self._override_model_version(inference_params, output_config)
+    model_info = self._get_model_info_for_inference(inference_params, output_config)
     request = service_pb2.PostModelOutputsRequest(
         user_app_id=self.user_app_id,
         model_id=self.id,
         version_id=self.model_version.id,
         inputs=inputs,
         runner_selector=runner_selector,
-        model=self.model_info)
+        model=model_info)
+    request.model.model_version.id = self.model_version.id
+    request.model.model_version.params
 
     start_time = time.time()
     backoff_iterator = BackoffIterator(10)
@@ -922,15 +924,15 @@ def generate_by_url(self,
         inference_params=inference_params,
         output_config=output_config)
 
-  def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector):
+  def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector, model_info: resources_pb2.Model):
     for inputs in input_iterator:
       yield service_pb2.PostModelOutputsRequest(
           user_app_id=self.user_app_id,
           model_id=self.id,
           version_id=self.model_version.id,
           inputs=inputs,
           runner_selector=runner_selector,
-          model=self.model_info)
+          model=model_info)
 
   def stream(self,
              inputs: Iterator[List[Input]],
@@ -954,8 +956,8 @@ def stream(self,
     # if not isinstance(inputs, Iterator[List[Input]]):
     #   raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
 
-    self._override_model_version(inference_params, output_config)
-    request = self._req_iterator(inputs, runner_selector)
+    model_info = self._get_model_info_for_inference(inference_params, output_config)
+    request = self._req_iterator(inputs, runner_selector, model_info)
 
     start_time = time.time()
     backoff_iterator = BackoffIterator(10)
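
stream() builds its requests lazily, so the per-call model_info is now passed into _req_iterator instead of being read from self.model_info at the moment each request is yielded; every concurrent stream keeps its own snapshot. A hedged sketch of that generator pattern follows; the free function below is illustrative, not the client's actual method.

# Illustrative sketch of a request generator that closes over per-call state.
from typing import Iterator, List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2

def req_iterator(input_batches: Iterator[List[resources_pb2.Input]],
                 model_info: resources_pb2.Model) -> Iterator[service_pb2.PostModelOutputsRequest]:
  for inputs in input_batches:
    # model_info was resolved once, before streaming started; nothing shared
    # is touched while requests are being yielded.
    yield service_pb2.PostModelOutputsRequest(inputs=inputs, model=model_info)
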
@@ -1168,7 +1170,7 @@ def input_generator():
         inference_params=inference_params,
         output_config=output_config)
 
-  def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
+  def _get_model_info_for_inference(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
     """Overrides the model version.
 
     Args:
@@ -1179,13 +1181,14 @@ def _override_model_version(self, inference_params: Dict = {}, output_config: Di
         select_concepts (list[Concept]): The concepts to select.
         sample_ms (int): The number of milliseconds to sample.
     """
-    params = Struct()
-    if inference_params is not None:
-      params.update(inference_params)
-
-    self.model_info.model_version.output_info.CopyFrom(
-        resources_pb2.OutputInfo(
-            output_config=resources_pb2.OutputConfig(**output_config), params=params))
+    if not inference_params and not output_config:
+      return self.model_info
+
+    model_info = resources_pb2.Model()
+    model_info.CopyFrom(self.model_info)
+    model_info.model_version.output_info.params.update(inference_params)
+    model_info.model_version.output_info.output_config.update(output_config)
+    return model_info
 
   def _list_concepts(self) -> List[str]:
     """Lists all the concepts for the model type.
