Skip to content

Commit 331ea8d

Browse files
rafidkadmontagu and David Montague authored
Add support for Cohere models (#203)
Co-authored-by: David Montague <[email protected]>
1 parent b8d7136 commit 331ea8d

File tree

8 files changed

+739
-10
lines changed

8 files changed

+739
-10
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,22 @@
8181
'anthropic:claude-3-5-haiku-latest',
8282
'anthropic:claude-3-5-sonnet-latest',
8383
'anthropic:claude-3-opus-latest',
84+
'claude-3-5-haiku-latest',
85+
'claude-3-5-sonnet-latest',
86+
'claude-3-opus-latest',
87+
'cohere:c4ai-aya-expanse-32b',
88+
'cohere:c4ai-aya-expanse-8b',
89+
'cohere:command',
90+
'cohere:command-light',
91+
'cohere:command-light-nightly',
92+
'cohere:command-nightly',
93+
'cohere:command-r',
94+
'cohere:command-r-03-2024',
95+
'cohere:command-r-08-2024',
96+
'cohere:command-r-plus',
97+
'cohere:command-r-plus-04-2024',
98+
'cohere:command-r-plus-08-2024',
99+
'cohere:command-r7b-12-2024',
84100
'test',
85101
]
86102
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -228,6 +244,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
228244
from .test import TestModel
229245

230246
return TestModel()
247+
elif model.startswith('cohere:'):
248+
from .cohere import CohereModel
249+
250+
return CohereModel(model[7:])
231251
elif model.startswith('openai:'):
232252
from .openai import OpenAIModel
233253

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
from __future__ import annotations as _annotations
2+
3+
from collections.abc import Iterable
4+
from dataclasses import dataclass, field
5+
from itertools import chain
6+
from typing import Literal, TypeAlias, Union
7+
8+
from cohere import TextAssistantMessageContentItem
9+
from typing_extensions import assert_never
10+
11+
from .. import result
12+
from .._utils import guard_tool_call_id as _guard_tool_call_id
13+
from ..messages import (
14+
ModelMessage,
15+
ModelRequest,
16+
ModelResponse,
17+
ModelResponsePart,
18+
RetryPromptPart,
19+
SystemPromptPart,
20+
TextPart,
21+
ToolCallPart,
22+
ToolReturnPart,
23+
UserPromptPart,
24+
)
25+
from ..settings import ModelSettings
26+
from ..tools import ToolDefinition
27+
from . import (
28+
AgentModel,
29+
Model,
30+
check_allow_model_requests,
31+
)
32+
33+
try:
34+
from cohere import (
35+
AssistantChatMessageV2,
36+
AsyncClientV2,
37+
ChatMessageV2,
38+
ChatResponse,
39+
SystemChatMessageV2,
40+
ToolCallV2,
41+
ToolCallV2Function,
42+
ToolChatMessageV2,
43+
ToolV2,
44+
ToolV2Function,
45+
UserChatMessageV2,
46+
)
47+
from cohere.v2.client import OMIT
48+
except ImportError as _import_error:
49+
raise ImportError(
50+
'Please install `cohere` to use the Cohere model, '
51+
"you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
52+
) from _import_error
53+
54+
# Known Cohere model identifiers. Plain `str` is also part of the union so
# newly released model names can be used before this list is updated.
CohereModelName: TypeAlias = Union[
    str,
    Literal[
        'c4ai-aya-expanse-32b',
        'c4ai-aya-expanse-8b',
        'command',
        'command-light',
        'command-light-nightly',
        'command-nightly',
        'command-r',
        'command-r-03-2024',
        'command-r-08-2024',
        'command-r-plus',
        'command-r-plus-04-2024',
        'command-r-plus-08-2024',
        'command-r7b-12-2024',
    ],
]
72+
73+
74+
@dataclass(init=False)
class CohereModel(Model):
    """A model that uses the Cohere API.

    Internally, this uses the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: CohereModelName
    client: AsyncClientV2 = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        api_key: str | None = None,
        cohere_client: AsyncClientV2 | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            api_key: The API key to use for authentication, if not provided, the
                `COHERE_API_KEY` environment variable will be used if available.
            cohere_client: An existing Cohere async client to use. If provided,
                `api_key` must be `None`.

        Raises:
            ValueError: If both `cohere_client` and `api_key` are provided.
        """
        self.model_name: CohereModelName = model_name
        if cohere_client is not None:
            # Explicit error rather than `assert` so the check is not stripped
            # when Python runs with `-O`.
            if api_key is not None:
                raise ValueError('Cannot provide both `cohere_client` and `api_key`')
            self.client = cohere_client
        else:
            # The client falls back to the `COHERE_API_KEY` env var when api_key is None.
            self.client = AsyncClientV2(api_key=api_key)  # type: ignore

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Create an agent model, mapping all tool definitions to Cohere's format."""
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return CohereAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        """Return the fully-qualified model name, e.g. `cohere:command-r-plus`."""
        return f'cohere:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Convert a pydantic_ai `ToolDefinition` into a Cohere `ToolV2`."""
        return ToolV2(
            type='function',
            function=ToolV2Function(
                name=f.name,
                description=f.description,
                parameters=f.parameters_json_schema,
            ),
        )
142+
143+
144+
@dataclass
class CohereAgentModel(AgentModel):
    """Implementation of `AgentModel` for Cohere models."""

    # Shared async client used for all requests made by this agent model.
    client: AsyncClientV2
    # Cohere model identifier, e.g. 'command-r-plus'.
    model_name: CohereModelName
    # Whether a plain-text (non-tool-call) result is acceptable to the caller.
    allow_text_result: bool
    # Function + result tools, already mapped to Cohere's `ToolV2` format.
    tools: list[ToolV2]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        """Make a non-streaming request and return the response plus its usage stats."""
        response = await self._chat(messages, model_settings)
        return self._process_response(response), _map_usage(response)

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
    ) -> ChatResponse:
        """Send the mapped message history to Cohere's chat endpoint."""
        # Each pydantic_ai message may map to several Cohere messages; flatten them.
        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
        model_settings = model_settings or {}
        return await self.client.chat(
            model=self.model_name,
            # `OMIT` is Cohere's sentinel for "field not provided".
            messages=cohere_messages,
            tools=self.tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            p=model_settings.get('top_p', OMIT),
        )

    @staticmethod
    def _process_response(response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        parts: list[ModelResponsePart] = []
        if response.message.content is not None and len(response.message.content) > 0:
            # While Cohere's API returns a list, it only does that for future proofing
            # and currently only one item is being returned.
            choice = response.message.content[0]
            parts.append(TextPart(choice.text))
        for c in response.message.tool_calls or []:
            # Skip incomplete tool calls: name and arguments are both required
            # to build a `ToolCallPart`.
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart.from_raw_args(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id,
                    )
                )
        return ModelResponse(parts=parts)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = AssistantChatMessageV2(role='assistant')
            if texts:
                # Multiple text parts are merged into a single content item.
                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Map each part of a `ModelRequest` to the corresponding Cohere message type."""
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield UserChatMessageV2(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    # Retry not tied to a tool call: surface it as a plain user message.
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)
243+
244+
245+
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
    """Convert a pydantic_ai `ToolCallPart` into a Cohere `ToolCallV2`."""
    function = ToolCallV2Function(
        name=t.tool_name,
        arguments=t.args_as_json_str(),
    )
    call_id = _guard_tool_call_id(t=t, model_source='Cohere')
    return ToolCallV2(id=call_id, type='function', function=function)
254+
255+
256+
def _map_usage(response: ChatResponse) -> result.Usage:
    """Translate the usage block of a Cohere response into a `result.Usage`."""
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed = usage.billed_units
    if billed is not None:
        # Record only the billed-unit fields that are present and non-zero.
        for key, value in (
            ('input_tokens', billed.input_tokens),
            ('output_tokens', billed.output_tokens),
            ('search_units', billed.search_units),
            ('classifications', billed.classifications),
        ):
            if value:
                details[key] = int(value)

    tokens = usage.tokens
    request_tokens = int(tokens.input_tokens) if tokens and tokens.input_tokens else None
    response_tokens = int(tokens.output_tokens) if tokens and tokens.output_tokens else None
    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details=details,
    )

pydantic_ai_slim/pydantic_ai/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ModelSettings(TypedDict, total=False):
2424
* Anthropic
2525
* OpenAI
2626
* Groq
27+
* Cohere
2728
"""
2829

2930
temperature: float
@@ -40,6 +41,7 @@ class ModelSettings(TypedDict, total=False):
4041
* Anthropic
4142
* OpenAI
4243
* Groq
44+
* Cohere
4345
"""
4446

4547
top_p: float
@@ -55,6 +57,7 @@ class ModelSettings(TypedDict, total=False):
5557
* Anthropic
5658
* OpenAI
5759
* Groq
60+
* Cohere
5861
"""
5962

6063
timeout: float | Timeout

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies = [
4646
logfire = ["logfire>=2.3"]
4747
graph = ["pydantic-graph==0.0.19"]
4848
openai = ["openai>=1.54.3"]
49+
cohere = ["cohere>=5.13.4"]
4950
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
5051
anthropic = ["anthropic>=0.40.0"]
5152
groq = ["groq>=0.12.0"]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ classifiers = [
3737
]
3838
requires-python = ">=3.9"
3939

40-
dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral]==0.0.19"]
40+
dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral,cohere]==0.0.19"]
4141

4242
[project.urls]
4343
Homepage = "https://ai.pydantic.dev"

0 commit comments

Comments
 (0)