@@ -424,14 +424,14 @@ def predict(self,
424424 raise UserError (f"Too many inputs. Max is { MAX_MODEL_PREDICT_INPUTS } ."
425425 ) # TODO Use Chunker for inputs len > 128
426426
427- self ._override_model_version (inference_params , output_config )
427+ model_info = self ._get_model_info_for_inference (inference_params , output_config )
428428 request = service_pb2 .PostModelOutputsRequest (
429429 user_app_id = self .user_app_id ,
430430 model_id = self .id ,
431431 version_id = self .model_version .id ,
432432 inputs = inputs ,
433433 runner_selector = runner_selector ,
434- model = self . model_info )
434+ model = model_info )
435435
436436 start_time = time .time ()
437437 backoff_iterator = BackoffIterator (10 )
@@ -704,14 +704,16 @@ def generate(self,
704704 raise UserError (f"Too many inputs. Max is { MAX_MODEL_PREDICT_INPUTS } ."
705705 ) # TODO Use Chunker for inputs len > 128
706706
707- self ._override_model_version (inference_params , output_config )
707+ model_info = self ._get_model_info_for_inference (inference_params , output_config )
708708 request = service_pb2 .PostModelOutputsRequest (
709709 user_app_id = self .user_app_id ,
710710 model_id = self .id ,
711711 version_id = self .model_version .id ,
712712 inputs = inputs ,
713713 runner_selector = runner_selector ,
714- model = self .model_info )
714+ model = model_info )
715+ request .model .model_version .id = self .model_version .id
716+ request .model .model_version .params
715717
716718 start_time = time .time ()
717719 backoff_iterator = BackoffIterator (10 )
@@ -922,15 +924,15 @@ def generate_by_url(self,
922924 inference_params = inference_params ,
923925 output_config = output_config )
924926
925- def _req_iterator (self , input_iterator : Iterator [List [Input ]], runner_selector : RunnerSelector ):
927+ def _req_iterator (self , input_iterator : Iterator [List [Input ]], runner_selector : RunnerSelector , model_info : resources_pb2 . Model ):
926928 for inputs in input_iterator :
927929 yield service_pb2 .PostModelOutputsRequest (
928930 user_app_id = self .user_app_id ,
929931 model_id = self .id ,
930932 version_id = self .model_version .id ,
931933 inputs = inputs ,
932934 runner_selector = runner_selector ,
933- model = self . model_info )
935+ model = model_info )
934936
935937 def stream (self ,
936938 inputs : Iterator [List [Input ]],
@@ -954,8 +956,8 @@ def stream(self,
954956 # if not isinstance(inputs, Iterator[List[Input]]):
955957 # raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
956958
957- self ._override_model_version (inference_params , output_config )
958- request = self ._req_iterator (inputs , runner_selector )
959+ model_info = self ._get_model_info_for_inference (inference_params , output_config )
960+ request = self ._req_iterator (inputs , runner_selector , model_info )
959961
960962 start_time = time .time ()
961963 backoff_iterator = BackoffIterator (10 )
@@ -1168,7 +1170,7 @@ def input_generator():
11681170 inference_params = inference_params ,
11691171 output_config = output_config )
11701172
1171- def _override_model_version (self , inference_params : Dict = {}, output_config : Dict = {}) -> None :
1173+ def _get_model_info_for_inference (self , inference_params : Dict = {}, output_config : Dict = {}) -> None :
11721174 """Overrides the model version.
11731175
11741176 Args:
@@ -1179,13 +1181,14 @@ def _override_model_version(self, inference_params: Dict = {}, output_config: Di
11791181 select_concepts (list[Concept]): The concepts to select.
11801182 sample_ms (int): The number of milliseconds to sample.
11811183 """
1182- params = Struct ()
1183- if inference_params is not None :
1184- params .update (inference_params )
1185-
1186- self .model_info .model_version .output_info .CopyFrom (
1187- resources_pb2 .OutputInfo (
1188- output_config = resources_pb2 .OutputConfig (** output_config ), params = params ))
1184+ if not inference_params and not output_config :
1185+ return self .model_info
1186+
1187+ model_info = resources_pb2 .Model ()
1188+ model_info .CopyFrom (self .model_info )
1189+ model_info .model_version .output_info .params .update (inference_params )
1190+ model_info .model_version .output_info .output_config .update (output_config )
1191+ return model_info
11891192
11901193 def _list_concepts (self ) -> List [str ]:
11911194 """Lists all the concepts for the model type.
0 commit comments