Skip to content

Commit d9f895f

Browse files
Anirudh31415926535marklyszethinkall
authored andcommitted
fix: tool calling cohere (microsoft#3355)
* Add support for tool calling cohere * update tool calling code * make client name configurable with default * formatting nits * update docs --------- Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com> Co-authored-by: Li Jiang <bnujli@gmail.com>
1 parent 6998bae commit d9f895f

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

autogen/oai/cohere.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"api_type": "cohere",
1313
"model": "command-r-plus",
1414
"api_key": os.environ.get("COHERE_API_KEY")
15+
"client_name": "autogen-cohere", # Optional parameter
1516
}
1617
]}
1718
@@ -150,7 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
150151
def create(self, params: Dict) -> ChatCompletion:
151152

152153
messages = params.get("messages", [])
153-
154+
client_name = params.get("client_name") or "autogen-cohere"
154155
# Parse parameters to the Cohere API's parameters
155156
cohere_params = self.parse_params(params)
156157

@@ -162,7 +163,7 @@ def create(self, params: Dict) -> ChatCompletion:
162163
cohere_params["preamble"] = preamble
163164

164165
# We use chat model by default
165-
client = Cohere(api_key=self.api_key)
166+
client = Cohere(api_key=self.api_key, client_name=client_name)
166167

167168
# Token counts will be returned
168169
prompt_tokens = 0
@@ -291,6 +292,23 @@ def create(self, params: Dict) -> ChatCompletion:
291292
return response_oai
292293

293294

295+
def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
296+
temp_tool_results = []
297+
298+
for tool_call in all_tool_calls:
299+
if tool_call["id"] == tool_call_id:
300+
301+
call = {
302+
"name": tool_call["function"]["name"],
303+
"parameters": json.loads(
304+
tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] == "" else "{}"
305+
),
306+
}
307+
output = [{"value": content_output}]
308+
temp_tool_results.append(ToolResult(call=call, outputs=output))
309+
return temp_tool_results
310+
311+
294312
def oai_messages_to_cohere_messages(
295313
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
296314
) -> tuple[list[dict[str, Any]], str, str]:
@@ -358,7 +376,8 @@ def oai_messages_to_cohere_messages(
358376
# 'content' field renamed to 'message'
359377
# tools go into tools parameter
360378
# tool_results go into tool_results parameter
361-
for message in messages:
379+
messages_length = len(messages)
380+
for index, message in enumerate(messages):
362381

363382
if "role" in message and message["role"] == "system":
364383
# System message
@@ -375,34 +394,34 @@ def oai_messages_to_cohere_messages(
375394
new_message = {
376395
"role": "CHATBOT",
377396
"message": message["content"],
378-
# Not including tools in this message, may need to. Testing required.
397+
"tool_calls": [
398+
{
399+
"name": tool_call_.get("function", {}).get("name"),
400+
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
401+
}
402+
for tool_call_ in message["tool_calls"]
403+
],
379404
}
380405

381406
cohere_messages.append(new_message)
382407
elif "role" in message and message["role"] == "tool":
383-
if "tool_call_id" in message:
384-
# Convert the tool call to a result
408+
if not (tool_call_id := message.get("tool_call_id")):
409+
continue
410+
411+
# Convert the tool call to a result
412+
content_output = message["content"]
413+
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
414+
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
415+
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
416+
# So, we pass it into tool_results.
417+
tool_results.extend(tool_results_chat_turn)
418+
continue
385419

386-
tool_call_id = message["tool_call_id"]
387-
content_output = message["content"]
388-
389-
# Find the original tool
390-
for tool_call in tool_calls:
391-
if tool_call["id"] == tool_call_id:
392-
393-
call = {
394-
"name": tool_call["function"]["name"],
395-
"parameters": json.loads(
396-
tool_call["function"]["arguments"]
397-
if not tool_call["function"]["arguments"] == ""
398-
else "{}"
399-
),
400-
}
401-
output = [{"value": content_output}]
402-
403-
tool_results.append(ToolResult(call=call, outputs=output))
420+
else:
421+
# If its not the current tool call, we pass it as a tool message in the chat history.
422+
new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
423+
cohere_messages.append(new_message)
404424

405-
break
406425
elif "content" in message and isinstance(message["content"], str):
407426
# Standard text message
408427
new_message = {
@@ -422,7 +441,7 @@ def oai_messages_to_cohere_messages(
422441
# If we're adding tool_results, like we are, the last message can't be a USER message
423442
# So, we add a CHATBOT 'continue' message, if so.
424443
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
425-
if cohere_messages[-1]["role"] == "USER":
444+
if cohere_messages[-1]["role"].lower() == "user":
426445
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
427446

428447
# We return a blank message when we have tool results

website/docs/topics/non-openai-models/cloud-cohere.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
"- seed (null, integer)\n",
101101
"- frequency_penalty (number 0..1)\n",
102102
"- presence_penalty (number 0..1)\n",
103+
"- client_name (null, string)\n",
103104
"\n",
104105
"Example:\n",
105106
"```python\n",
@@ -108,6 +109,7 @@
108109
" \"model\": \"command-r\",\n",
109110
" \"api_key\": \"your Cohere API Key goes here\",\n",
110111
" \"api_type\": \"cohere\",\n",
112+
" \"client_name\": \"autogen-cohere\",\n",
111113
" \"temperature\": 0.5,\n",
112114
" \"p\": 0.2,\n",
113115
" \"k\": 100,\n",
@@ -526,7 +528,7 @@
526528
"name": "python",
527529
"nbconvert_exporter": "python",
528530
"pygments_lexer": "ipython3",
529-
"version": "3.11.9"
531+
"version": "3.12.5"
530532
}
531533
},
532534
"nbformat": 4,

0 commit comments

Comments
 (0)