@@ -313,7 +313,6 @@ def __init__(
313313 ** kwargs ,
314314 ):
315315 self .tools = kwargs .pop ("tools" , None )
316- self .tool_choice = kwargs .pop ("tool_choice" , None )
317316 super ().__init__ (
318317 model_name = model_name ,
319318 api_key = api_key ,
@@ -324,7 +323,9 @@ def __init__(
324323 )
325324 self .client = OpenAI (api_key = api_key )
326325
327- def _call_api (self , messages : list [Any | MessageBuilder ], ** kwargs ) -> dict :
326+ def _call_api (
327+ self , messages : list [Any | MessageBuilder ], tool_choice : str = "auto" , ** kwargs
328+ ) -> dict :
328329 input = []
329330 for msg in messages :
330331 input .extend (msg .prepare_message () if isinstance (msg , MessageBuilder ) else [msg ])
@@ -339,8 +340,10 @@ def _call_api(self, messages: list[Any | MessageBuilder], **kwargs) -> dict:
339340
340341 if self .tools is not None :
341342 api_params ["tools" ] = self .tools
342- if self .tool_choice is not None :
343- api_params ["tool_choice" ] = self .tool_choice
343+ if tool_choice in ("any" , "required" ):
344+ tool_choice = "required"
345+
346+ api_params ["tool_choice" ] = tool_choice
344347
345348 # api_params |= kwargs # Merge any additional parameters passed
346349 response = call_openai_api_with_retries (
@@ -388,7 +391,6 @@ def __init__(
388391 ):
389392
390393 self .tools = self .format_tools_for_chat_completion (kwargs .pop ("tools" , None ))
391- self .tool_choice = kwargs .pop ("tool_choice" , None )
392394
393395 super ().__init__ (
394396 model_name = model_name ,
@@ -403,7 +405,9 @@ def __init__(
403405 ** client_args
404406 ) # Ensures client_args is a dict or defaults to an empty dict
405407
406- def _call_api (self , messages : list [dict | MessageBuilder ]) -> openai .types .chat .ChatCompletion :
408+ def _call_api (
409+ self , messages : list [dict | MessageBuilder ], tool_choice : str = "auto"
410+ ) -> openai .types .chat .ChatCompletion :
407411 input = []
408412 for msg in messages :
409413 input .extend (msg .prepare_message () if isinstance (msg , MessageBuilder ) else [msg ])
@@ -416,8 +420,10 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> openai.types.chat.
416420 }
417421 if self .tools is not None :
418422 api_params ["tools" ] = self .tools
419- if self .tool_choice is not None :
420- api_params ["tool_choice" ] = self .tool_choice
423+
424+ if tool_choice in ("any" , "required" ):
425+ tool_choice = "required"
426+ api_params ["tool_choice" ] = tool_choice
421427
422428 response = call_openai_api_with_retries (self .client .chat .completions .create , api_params )
423429
@@ -517,7 +523,6 @@ def __init__(
517523 ** kwargs ,
518524 ):
519525 self .tools = kwargs .pop ("tools" , None )
520- self .tool_choice = kwargs .pop ("tool_choice" , None )
521526
522527 super ().__init__ (
523528 model_name = model_name ,
@@ -543,6 +548,9 @@ def _call_api(
543548 temp = self .apply_cache_breakpoints (msg , temp )
544549 input .extend (temp )
545550
551+ if tool_choice in ("any" , "required" ):
552+ tool_choice = "any" # Claude API expects "any" and gpt expects "required"
553+
546554 api_params : Dict [str , Any ] = {
547555 "model" : self .model_name ,
548556 "messages" : input ,
0 commit comments