Skip to content

Commit a590014

Browse files
hrahmadi71Kludex
andauthored
Add cohere provider class (#1225)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent e328fc3 commit a590014

File tree

10 files changed

+280
-18
lines changed

10 files changed

+280
-18
lines changed

docs/models.md

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -611,20 +611,39 @@ Or initialise the model directly with just the model name:
611611
from pydantic_ai import Agent
612612
from pydantic_ai.models.cohere import CohereModel
613613

614-
model = CohereModel('command', api_key='your-api-key')
614+
model = CohereModel('command')
615615
agent = Agent(model)
616616
...
617617
```
618618

619-
### `api_key` argument
619+
### `provider` argument
620620

621-
If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.cohere.CohereModel.__init__]:
621+
You can provide a custom [`Provider`][pydantic_ai.providers.Provider] via the [`provider` argument][pydantic_ai.models.cohere.CohereModel.__init__]:
622622

623-
```python {title="cohere_model_api_key.py"}
623+
```python {title="cohere_model_provider.py"}
624624
from pydantic_ai import Agent
625625
from pydantic_ai.models.cohere import CohereModel
626+
from pydantic_ai.providers.cohere import CohereProvider
626627

627-
model = CohereModel('command', api_key='your-api-key')
628+
model = CohereModel('command', provider=CohereProvider(api_key='your-api-key'))
629+
agent = Agent(model)
630+
...
631+
```
632+
633+
You can also customize the `CohereProvider` with a custom `http_client`:
634+
635+
```python {title="cohere_model_custom_provider.py"}
636+
from httpx import AsyncClient
637+
638+
from pydantic_ai import Agent
639+
from pydantic_ai.models.cohere import CohereModel
640+
from pydantic_ai.providers.cohere import CohereProvider
641+
642+
custom_http_client = AsyncClient(timeout=30)
643+
model = CohereModel(
644+
'command',
645+
provider=CohereProvider(api_key='your-api-key', http_client=custom_http_client),
646+
)
628647
agent = Agent(model)
629648
...
630649
```

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
402402
if provider == 'cohere':
403403
from .cohere import CohereModel
404404

405-
# TODO(Marcelo): Missing provider API.
406-
return CohereModel(model_name)
405+
return CohereModel(model_name, provider=provider)
407406
elif provider in ('deepseek', 'openai'):
408407
from .openai import OpenAIModel
409408

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from collections.abc import Iterable
44
from dataclasses import dataclass, field
55
from itertools import chain
6-
from typing import Literal, Union, cast
6+
from typing import Literal, Union, cast, overload
77

88
from cohere import TextAssistantMessageContentItem
99
from httpx import AsyncClient as AsyncHTTPClient
10-
from typing_extensions import assert_never
10+
from typing_extensions import assert_never, deprecated
1111

1212
from .. import ModelHTTPError, result
1313
from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -23,11 +23,13 @@
2323
ToolReturnPart,
2424
UserPromptPart,
2525
)
26+
from ..providers import Provider, infer_provider
2627
from ..settings import ModelSettings
2728
from ..tools import ToolDefinition
2829
from . import (
2930
Model,
3031
ModelRequestParameters,
32+
cached_async_http_client,
3133
check_allow_model_requests,
3234
)
3335

@@ -100,10 +102,34 @@ class CohereModel(Model):
100102
_model_name: CohereModelName = field(repr=False)
101103
_system: str = field(default='cohere', repr=False)
102104

105+
@overload
103106
def __init__(
104107
self,
105108
model_name: CohereModelName,
106109
*,
110+
provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
111+
api_key: None = None,
112+
cohere_client: None = None,
113+
http_client: None = None,
114+
) -> None: ...
115+
116+
@deprecated('Use the `provider` parameter instead of `api_key`, `cohere_client`, and `http_client`.')
117+
@overload
118+
def __init__(
119+
self,
120+
model_name: CohereModelName,
121+
*,
122+
provider: None = None,
123+
api_key: str | None = None,
124+
cohere_client: AsyncClientV2 | None = None,
125+
http_client: AsyncHTTPClient | None = None,
126+
) -> None: ...
127+
128+
def __init__(
129+
self,
130+
model_name: CohereModelName,
131+
*,
132+
provider: Literal['cohere'] | Provider[AsyncClientV2] | None = None,
107133
api_key: str | None = None,
108134
cohere_client: AsyncClientV2 | None = None,
109135
http_client: AsyncHTTPClient | None = None,
@@ -113,19 +139,27 @@ def __init__(
113139
Args:
114140
model_name: The name of the Cohere model to use. List of model names
115141
available [here](https://docs.cohere.com/docs/models#command).
142+
provider: The provider to use for authentication and API access. Can be either the string
143+
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
144+
created using the other parameters.
116145
api_key: The API key to use for authentication, if not provided, the
117146
`CO_API_KEY` environment variable will be used if available.
118147
cohere_client: An existing Cohere async client to use. If provided,
119148
`api_key` and `http_client` must be `None`.
120149
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
121150
"""
122151
self._model_name: CohereModelName = model_name
123-
if cohere_client is not None:
152+
153+
if provider is not None:
154+
if isinstance(provider, str):
155+
provider = infer_provider(provider)
156+
self.client = provider.client
157+
elif cohere_client is not None:
124158
assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
125159
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
126160
self.client = cohere_client
127161
else:
128-
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
162+
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client or cached_async_http_client())
129163

130164
@property
131165
def base_url(self) -> str:

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,9 @@ def infer_provider(provider: str) -> Provider[Any]:
7777
from .mistral import MistralProvider
7878

7979
return MistralProvider()
80+
elif provider == 'cohere':
81+
from .cohere import CohereProvider
82+
83+
return CohereProvider()
8084
else: # pragma: no cover
8185
raise ValueError(f'Unknown provider: {provider}')
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
5+
from httpx import AsyncClient as AsyncHTTPClient
6+
7+
from pydantic_ai.models import cached_async_http_client
8+
9+
try:
10+
from cohere import AsyncClientV2
11+
except ImportError as _import_error: # pragma: no cover
12+
raise ImportError(
13+
'Please install the `cohere` package to use the Cohere provider, '
14+
'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
15+
) from _import_error
16+
17+
18+
from . import Provider
19+
20+
21+
class CohereProvider(Provider[AsyncClientV2]):
22+
"""Provider for Cohere API."""
23+
24+
@property
25+
def name(self) -> str:
26+
return 'cohere'
27+
28+
@property
29+
def base_url(self) -> str:
30+
client_wrapper = self.client._client_wrapper # type: ignore
31+
return str(client_wrapper.get_base_url())
32+
33+
@property
34+
def client(self) -> AsyncClientV2:
35+
return self._client
36+
37+
def __init__(
38+
self,
39+
*,
40+
api_key: str | None = None,
41+
cohere_client: AsyncClientV2 | None = None,
42+
http_client: AsyncHTTPClient | None = None,
43+
) -> None:
44+
"""Create a new Cohere provider.
45+
46+
Args:
47+
api_key: The API key to use for authentication, if not provided, the `CO_API_KEY` environment variable
48+
will be used if available.
49+
cohere_client: An existing
50+
[AsyncClientV2](https://github.com/cohere-ai/cohere-python)
51+
client to use. If provided, `api_key` and `http_client` must be `None`.
52+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
53+
"""
54+
if cohere_client is not None:
55+
assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
56+
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
57+
self._client = cohere_client
58+
else:
59+
api_key = api_key or os.environ.get('CO_API_KEY')
60+
if api_key is None:
61+
raise ValueError(
62+
'Set the `CO_API_KEY` environment variable or pass it via `CohereProvider(api_key=...)`'
63+
'to use the Cohere provider.'
64+
)
65+
66+
base_url = os.environ.get('CO_BASE_URL')
67+
if http_client is not None:
68+
self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url)
69+
else:
70+
self._client = AsyncClientV2(
71+
api_key=api_key, httpx_client=cached_async_http_client(), base_url=base_url
72+
)

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ def anthropic_api_key() -> str:
246246
return os.getenv('ANTHROPIC_API_KEY', 'mock-api-key')
247247

248248

249+
@pytest.fixture(scope='session')
250+
def co_api_key() -> str:
251+
return os.getenv('CO_API_KEY', 'mock-api-key')
252+
253+
249254
@pytest.fixture
250255
def mock_snapshot_id(mocker: MockerFixture):
251256
i = 0
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- '*/*'
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '93'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.cohere.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: hello
20+
role: user
21+
model: command-r7b-12-2024
22+
stream: false
23+
uri: https://api.cohere.com/v2/chat
24+
response:
25+
headers:
26+
access-control-expose-headers:
27+
- X-Debug-Trace-ID
28+
alt-svc:
29+
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
30+
cache-control:
31+
- no-cache, no-store, no-transform, must-revalidate, private, max-age=0
32+
content-length:
33+
- '286'
34+
content-type:
35+
- application/json
36+
expires:
37+
- Thu, 01 Jan 1970 00:00:00 UTC
38+
num_chars:
39+
- '2583'
40+
num_tokens:
41+
- '10'
42+
pragma:
43+
- no-cache
44+
vary:
45+
- Origin
46+
parsed_body:
47+
finish_reason: COMPLETE
48+
id: f17a5f6c-1734-4098-bd0d-733ef000ac7b
49+
message:
50+
content:
51+
- text: Hello! How can I assist you today?
52+
type: text
53+
role: assistant
54+
usage:
55+
billed_units:
56+
input_tokens: 1
57+
output_tokens: 9
58+
tokens:
59+
input_tokens: 496
60+
output_tokens: 11
61+
status:
62+
code: 200
63+
message: OK
64+
version: 1

tests/models/test_cohere.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from cohere.core.api_error import ApiError
3939

4040
from pydantic_ai.models.cohere import CohereModel
41+
from pydantic_ai.providers.cohere import CohereProvider
4142

4243
# note: we use Union here for compatibility with Python 3.9
4344
MockChatResponse = Union[ChatResponse, Exception]
@@ -49,7 +50,7 @@
4950

5051

5152
def test_init():
52-
m = CohereModel('command-r7b-12-2024', api_key='foobar')
53+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(api_key='foobar'))
5354
assert m.model_name == 'command-r7b-12-2024'
5455
assert m.system == 'cohere'
5556
assert m.base_url == 'https://api.cohere.com'
@@ -96,7 +97,7 @@ async def test_request_simple_success(allow_model_requests: None):
9697
)
9798
)
9899
mock_client = MockAsyncClientV2.create_mock(c)
99-
m = CohereModel('command-r7b-12-2024', cohere_client=mock_client)
100+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(cohere_client=mock_client))
100101
agent = Agent(m)
101102

102103
result = await agent.run('hello')
@@ -135,7 +136,7 @@ async def test_request_simple_usage(allow_model_requests: None):
135136
),
136137
)
137138
mock_client = MockAsyncClientV2.create_mock(c)
138-
m = CohereModel('command-r7b-12-2024', cohere_client=mock_client)
139+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(cohere_client=mock_client))
139140
agent = Agent(m)
140141

141142
result = await agent.run('Hello')
@@ -169,7 +170,7 @@ async def test_request_structured_response(allow_model_requests: None):
169170
)
170171
)
171172
mock_client = MockAsyncClientV2.create_mock(c)
172-
m = CohereModel('command-r7b-12-2024', cohere_client=mock_client)
173+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(cohere_client=mock_client))
173174
agent = Agent(m, result_type=list[int])
174175

175176
result = await agent.run('Hello')
@@ -243,7 +244,7 @@ async def test_request_tool_call(allow_model_requests: None):
243244
),
244245
]
245246
mock_client = MockAsyncClientV2.create_mock(responses)
246-
m = CohereModel('command-r7b-12-2024', cohere_client=mock_client)
247+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(cohere_client=mock_client))
247248
agent = Agent(m, system_prompt='this is the system prompt')
248249

249250
@agent.tool_plain
@@ -326,7 +327,7 @@ async def get_location(loc_name: str) -> str:
326327
async def test_multimodal(allow_model_requests: None):
327328
c = completion_message(AssistantMessageResponse(content=[TextAssistantMessageResponseContentItem(text='world')]))
328329
mock_client = MockAsyncClientV2.create_mock(c)
329-
m = CohereModel('command-r7b-12-2024', cohere_client=mock_client)
330+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(cohere_client=mock_client))
330331
agent = Agent(m)
331332

332333
with pytest.raises(RuntimeError, match='Cohere does not yet support multi-modal inputs.'):
@@ -347,8 +348,16 @@ def test_model_status_error(allow_model_requests: None) -> None:
347348
body={'error': 'test error'},
348349
)
349350
)
350-
m = CohereModel('command-r', cohere_client=mock_client)
351+
m = CohereModel('command-r', provider=CohereProvider(cohere_client=mock_client))
351352
agent = Agent(m)
352353
with pytest.raises(ModelHTTPError) as exc_info:
353354
agent.run_sync('hello')
354355
assert str(exc_info.value) == snapshot("status_code: 500, model_name: command-r, body: {'error': 'test error'}")
356+
357+
358+
@pytest.mark.vcr()
359+
async def test_request_simple_success_with_vcr(allow_model_requests: None, co_api_key: str):
360+
m = CohereModel('command-r7b-12-2024', provider=CohereProvider(api_key=co_api_key))
361+
agent = Agent(m)
362+
result = await agent.run('hello')
363+
assert result.data == snapshot('Hello! How can I assist you today?')

0 commit comments

Comments
 (0)