@@ -339,7 +339,8 @@ def _form_prompt_chat_completions_api(
339339
340340 # Only return content directly if there's a single user message AND no tools
341341 if len (messages ) == 1 and messages [0 ].get ("role" ) == _USER_ROLE and (tools is None or len (tools ) == 0 ):
342- return output + str (messages [0 ]["content" ])
342+ first_msg = cast (dict [str , Any ], messages [0 ])
343+ return output + str (first_msg ["content" ])
343344
344345 # Warn if the last message is an assistant message with tool calls
345346 if messages and (messages [- 1 ].get ("role" ) == _ASSISTANT_ROLE or "tool_calls" in messages [- 1 ]):
@@ -358,22 +359,26 @@ def _form_prompt_chat_completions_api(
358359 if msg ["role" ] == _ASSISTANT_ROLE :
359360 output += _ASSISTANT_PREFIX
360361 # Handle content if present
361- if msg .get ("content" ):
362- output += f"{ msg ['content' ]} \n \n "
362+ content_value = cast (Optional [str ], msg .get ("content" ))
363+ if content_value :
364+ output += f"{ content_value } \n \n "
363365 # Handle tool calls if present
364366 if "tool_calls" in msg :
365367 for tool_call in msg ["tool_calls" ]:
366- call_id = tool_call ["id" ]
367- function_names [call_id ] = tool_call ["function" ]["name" ]
368- # Format function call as JSON within XML tags, now including call_id
369- function_call = {
370- "name" : tool_call ["function" ]["name" ],
371- "arguments" : json .loads (tool_call ["function" ]["arguments" ])
372- if tool_call ["function" ]["arguments" ]
373- else {},
374- "call_id" : call_id ,
375- }
376- output += f"{ _TOOL_CALL_TAG_START } \n { json .dumps (function_call , indent = 2 )} \n { _TOOL_CALL_TAG_END } \n \n "
368+ if tool_call ["type" ] == "function" :
369+ call_id = tool_call ["id" ]
370+ function_names [call_id ] = tool_call ["function" ]["name" ]
371+ # Format function call as JSON within XML tags, now including call_id
372+ function_call = {
373+ "name" : tool_call ["function" ]["name" ],
374+ "arguments" : json .loads (tool_call ["function" ]["arguments" ])
375+ if tool_call ["function" ]["arguments" ]
376+ else {},
377+ "call_id" : call_id ,
378+ }
379+ output += (
380+ f"{ _TOOL_CALL_TAG_START } \n { json .dumps (function_call , indent = 2 )} \n { _TOOL_CALL_TAG_END } \n \n "
381+ )
377382 elif msg ["role" ] == _TOOL_ROLE :
378383 # Handle tool responses
379384 output += _TOOL_PREFIX
@@ -506,13 +511,19 @@ def form_response_string_chat_completions_api(
506511 """
507512 response_dict = _response_to_dict (response )
508513 content = response_dict .get ("content" ) or ""
509- tool_calls = response_dict .get ("tool_calls" )
514+ tool_calls = cast ( Optional [ list [ dict [ str , Any ]]], response_dict .get ("tool_calls" ) )
510515 if tool_calls is not None :
511516 try :
512- tool_calls_str = "\n " .join (
513- f"{ _TOOL_CALL_TAG_START } \n { json .dumps ({'name' : call ['function' ]['name' ], 'arguments' : json .loads (call ['function' ]['arguments' ]) if call ['function' ]['arguments' ] else {}}, indent = 2 )} \n { _TOOL_CALL_TAG_END } "
514- for call in tool_calls
515- )
517+ rendered_calls : list [str ] = []
518+ for call in tool_calls :
519+ function_dict = call ["function" ]
520+ name = cast (str , function_dict ["name" ])
521+ args_str = cast (Optional [str ], function_dict .get ("arguments" ))
522+ args_obj = json .loads (args_str ) if args_str else {}
523+ rendered_calls .append (
524+ f"{ _TOOL_CALL_TAG_START } \n { json .dumps ({'name' : name , 'arguments' : args_obj }, indent = 2 )} \n { _TOOL_CALL_TAG_END } "
525+ )
526+ tool_calls_str = "\n " .join (rendered_calls )
516527 return f"{ content } \n { tool_calls_str } " .strip () if content else tool_calls_str
517528 except (KeyError , TypeError , json .JSONDecodeError ) as e :
518529 # Log the error but continue with just the content
0 commit comments