1212 "api_type": "cohere",
1313 "model": "command-r-plus",
1414 "api_key": os.environ.get("COHERE_API_KEY")
15+ "client_name": "autogen-cohere", # Optional parameter
1516 }
1617 ]}
1718
@@ -150,7 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
150151 def create (self , params : Dict ) -> ChatCompletion :
151152
152153 messages = params .get ("messages" , [])
153-
154+ client_name = params . get ( "client_name" ) or "autogen-cohere"
154155 # Parse parameters to the Cohere API's parameters
155156 cohere_params = self .parse_params (params )
156157
@@ -162,7 +163,7 @@ def create(self, params: Dict) -> ChatCompletion:
162163 cohere_params ["preamble" ] = preamble
163164
164165 # We use chat model by default
165- client = Cohere (api_key = self .api_key )
166+ client = Cohere (api_key = self .api_key , client_name = client_name )
166167
167168 # Token counts will be returned
168169 prompt_tokens = 0
@@ -291,6 +292,23 @@ def create(self, params: Dict) -> ChatCompletion:
291292 return response_oai
292293
293294
295+ def extract_to_cohere_tool_results (tool_call_id : str , content_output : str , all_tool_calls ) -> List [Dict [str , Any ]]:
296+ temp_tool_results = []
297+
298+ for tool_call in all_tool_calls :
299+ if tool_call ["id" ] == tool_call_id :
300+
301+ call = {
302+ "name" : tool_call ["function" ]["name" ],
303+ "parameters" : json .loads (
304+ tool_call ["function" ]["arguments" ] if not tool_call ["function" ]["arguments" ] == "" else "{}"
305+ ),
306+ }
307+ output = [{"value" : content_output }]
308+ temp_tool_results .append (ToolResult (call = call , outputs = output ))
309+ return temp_tool_results
310+
311+
294312def oai_messages_to_cohere_messages (
295313 messages : list [Dict [str , Any ]], params : Dict [str , Any ], cohere_params : Dict [str , Any ]
296314) -> tuple [list [dict [str , Any ]], str , str ]:
@@ -358,7 +376,8 @@ def oai_messages_to_cohere_messages(
358376 # 'content' field renamed to 'message'
359377 # tools go into tools parameter
360378 # tool_results go into tool_results parameter
361- for message in messages :
379+ messages_length = len (messages )
380+ for index , message in enumerate (messages ):
362381
363382 if "role" in message and message ["role" ] == "system" :
364383 # System message
@@ -375,34 +394,34 @@ def oai_messages_to_cohere_messages(
375394 new_message = {
376395 "role" : "CHATBOT" ,
377396 "message" : message ["content" ],
378- # Not including tools in this message, may need to. Testing required.
397+ "tool_calls" : [
398+ {
399+ "name" : tool_call_ .get ("function" , {}).get ("name" ),
400+ "parameters" : json .loads (tool_call_ .get ("function" , {}).get ("arguments" ) or "null" ),
401+ }
402+ for tool_call_ in message ["tool_calls" ]
403+ ],
379404 }
380405
381406 cohere_messages .append (new_message )
382407 elif "role" in message and message ["role" ] == "tool" :
383- if "tool_call_id" in message :
384- # Convert the tool call to a result
408+ if not (tool_call_id := message .get ("tool_call_id" )):
409+ continue
410+
411+ # Convert the tool call to a result
412+ content_output = message ["content" ]
413+ tool_results_chat_turn = extract_to_cohere_tool_results (tool_call_id , content_output , tool_calls )
414+ if (index == messages_length - 1 ) or (messages [index + 1 ].get ("role" , "" ).lower () in ("user" , "tool" )):
415+ # If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
416+ # So, we pass it into tool_results.
417+ tool_results .extend (tool_results_chat_turn )
418+ continue
385419
386- tool_call_id = message ["tool_call_id" ]
387- content_output = message ["content" ]
388-
389- # Find the original tool
390- for tool_call in tool_calls :
391- if tool_call ["id" ] == tool_call_id :
392-
393- call = {
394- "name" : tool_call ["function" ]["name" ],
395- "parameters" : json .loads (
396- tool_call ["function" ]["arguments" ]
397- if not tool_call ["function" ]["arguments" ] == ""
398- else "{}"
399- ),
400- }
401- output = [{"value" : content_output }]
402-
403- tool_results .append (ToolResult (call = call , outputs = output ))
420+ else :
421+ # If its not the current tool call, we pass it as a tool message in the chat history.
422+ new_message = {"role" : "TOOL" , "tool_results" : tool_results_chat_turn }
423+ cohere_messages .append (new_message )
404424
405- break
406425 elif "content" in message and isinstance (message ["content" ], str ):
407426 # Standard text message
408427 new_message = {
@@ -422,7 +441,7 @@ def oai_messages_to_cohere_messages(
422441 # If we're adding tool_results, like we are, the last message can't be a USER message
423442 # So, we add a CHATBOT 'continue' message, if so.
424443 # Changed key from "content" to "message" (jaygdesai/autogen_Jay)
425- if cohere_messages [- 1 ]["role" ] == "USER " :
444+ if cohere_messages [- 1 ]["role" ]. lower () == "user " :
426445 cohere_messages .append ({"role" : "CHATBOT" , "message" : "Please continue." })
427446
428447 # We return a blank message when we have tool results
0 commit comments