@@ -123,6 +123,12 @@ def __init__(
123123 self .chat_completion_parser = (
124124 chat_completion_parser or get_first_message_content
125125 )
126+ self .inference_parameters = [
127+ "maxTokens" ,
128+ "temperature" ,
129+ "topP" ,
130+ "stopSequences" ,
131+ ]
126132
127133 def init_sync_client (self ):
128134 """
@@ -233,6 +239,47 @@ def list_models(self):
233239 except Exception as e :
234240 print (f"Error listing models: { e } " )
235241
242+ def _validate_and_process_model_id (self , api_kwargs : Dict ):
243+ """
244+ Validate and process the model ID in API kwargs.
245+
246+ :param api_kwargs: Dictionary of API keyword arguments
247+ :raises KeyError: If 'model' key is missing
248+ """
249+ if "model" in api_kwargs :
250+ api_kwargs ["modelId" ] = api_kwargs .pop ("model" )
251+ else :
252+ raise KeyError ("The required key 'model' is missing in model_kwargs." )
253+
254+ return api_kwargs
255+
256+ def _separate_parameters (self , api_kwargs : Dict ) -> tuple :
257+ """
258+ Separate inference configuration and additional model request fields.
259+
260+ :param api_kwargs: Dictionary of API keyword arguments
261+ :return: Tuple of (inference_config, additional_model_request_fields)
262+ """
263+ inference_config = {}
264+ additional_model_request_fields = {}
265+ keys_to_remove = set ()
266+ excluded_keys = {"modelId" }
267+
268+ # Categorize parameters
269+ for key , value in list (api_kwargs .items ()):
270+ if key in self .inference_parameters :
271+ inference_config [key ] = value
272+ keys_to_remove .add (key )
273+ elif key not in excluded_keys :
274+ additional_model_request_fields [key ] = value
275+ keys_to_remove .add (key )
276+
277+ # Remove categorized keys from api_kwargs
278+ for key in keys_to_remove :
279+ api_kwargs .pop (key , None )
280+
281+ return api_kwargs , inference_config , additional_model_request_fields
282+
236283 def convert_inputs_to_api_kwargs (
237284 self ,
238285 input : Optional [Any ] = None ,
@@ -245,9 +292,17 @@ def convert_inputs_to_api_kwargs(
245292 """
246293 api_kwargs = model_kwargs .copy ()
247294 if model_type == ModelType .LLM :
295+ # Validate and process model ID
296+ api_kwargs = self ._validate_and_process_model_id (api_kwargs )
297+
298+ # Separate inference config and additional model request fields
299+ api_kwargs , inference_config , additional_model_request_fields = self ._separate_parameters (api_kwargs )
300+
248301 api_kwargs ["messages" ] = [
249302 {"role" : "user" , "content" : [{"text" : input }]},
250303 ]
304+ api_kwargs ["inferenceConfig" ] = inference_config
305+ api_kwargs ["additionalModelRequestFields" ] = additional_model_request_fields
251306 else :
252307 raise ValueError (f"Model type { model_type } not supported" )
253308 return api_kwargs
0 commit comments