@@ -118,6 +118,12 @@ def __init__(
118118 self .chat_completion_parser = (
119119 chat_completion_parser or get_first_message_content
120120 )
121+ self .inference_parameters = [
122+ "maxTokens" ,
123+ "temperature" ,
124+ "topP" ,
125+ "stopSequences" ,
126+ ]
121127
122128 def init_sync_client (self ):
123129 """
@@ -192,6 +198,47 @@ def list_models(self):
192198 except Exception as e :
193199 print (f"Error listing models: { e } " )
194200
201+ def _validate_and_process_model_id (self , api_kwargs : Dict ):
202+ """
203+ Validate and process the model ID in API kwargs.
204+
205+ :param api_kwargs: Dictionary of API keyword arguments
206+ :raises KeyError: If 'model' key is missing
207+ """
208+ if "model" in api_kwargs :
209+ api_kwargs ["modelId" ] = api_kwargs .pop ("model" )
210+ else :
211+ raise KeyError ("The required key 'model' is missing in model_kwargs." )
212+
213+ return api_kwargs
214+
215+ def _separate_parameters (self , api_kwargs : Dict ) -> tuple :
216+ """
217+ Separate inference configuration and additional model request fields.
218+
219+ :param api_kwargs: Dictionary of API keyword arguments
220+ :return: Tuple of (inference_config, additional_model_request_fields)
221+ """
222+ inference_config = {}
223+ additional_model_request_fields = {}
224+ keys_to_remove = set ()
225+ excluded_keys = {"modelId" }
226+
227+ # Categorize parameters
228+ for key , value in list (api_kwargs .items ()):
229+ if key in self .inference_parameters :
230+ inference_config [key ] = value
231+ keys_to_remove .add (key )
232+ elif key not in excluded_keys :
233+ additional_model_request_fields [key ] = value
234+ keys_to_remove .add (key )
235+
236+ # Remove categorized keys from api_kwargs
237+ for key in keys_to_remove :
238+ api_kwargs .pop (key , None )
239+
240+ return api_kwargs , inference_config , additional_model_request_fields
241+
195242 def convert_inputs_to_api_kwargs (
196243 self ,
197244 input : Optional [Any ] = None ,
@@ -204,9 +251,17 @@ def convert_inputs_to_api_kwargs(
204251 """
205252 api_kwargs = model_kwargs .copy ()
206253 if model_type == ModelType .LLM :
254+ # Validate and process model ID
255+ api_kwargs = self ._validate_and_process_model_id (api_kwargs )
256+
257+ # Separate inference config and additional model request fields
258+ api_kwargs , inference_config , additional_model_request_fields = self ._separate_parameters (api_kwargs )
259+
207260 api_kwargs ["messages" ] = [
208261 {"role" : "user" , "content" : [{"text" : input }]},
209262 ]
263+ api_kwargs ["inferenceConfig" ] = inference_config
264+ api_kwargs ["additionalModelRequestFields" ] = additional_model_request_fields
210265 else :
211266 raise ValueError (f"Model type { model_type } not supported" )
212267 return api_kwargs
0 commit comments