Skip to content

Commit f4fc435

Browse files
Add Cohere docs and additional (live) tests (#810)
Co-authored-by: David Montague <[email protected]>
1 parent ba4a598 commit f4fc435

File tree

8 files changed

+103
-23
lines changed

8 files changed

+103
-23
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ jobs:
9898
--extra groq
9999
--extra anthropic
100100
--extra mistral
101+
--extra cohere
101102
pytest tests/test_live.py -v
102103
--durations=100
103104
env:
@@ -108,6 +109,7 @@ jobs:
108109
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
109110
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
110111
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
112+
CO_API_KEY: ${{ secrets.COHERE_API_KEY }}
111113
112114
test:
113115
name: test on ${{ matrix.python-version }}

docs/api/models/cohere.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# `pydantic_ai.models.cohere`
2+
3+
## Setup
4+
5+
For details on how to set up authentication with this model, see [model configuration for Cohere](../../models.md#cohere).
6+
7+
::: pydantic_ai.models.cohere

docs/install.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ pip/uv-add 'pydantic-ai-slim[openai]'
5353
* `anthropic` — installs `anthropic` [PyPI ↗](https://pypi.org/project/anthropic){:target="_blank"}
5454
* `groq` — installs `groq` [PyPI ↗](https://pypi.org/project/groq){:target="_blank"}
5555
* `mistral` — installs `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"}
56+
* `cohere` - installs `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"}
5657

5758
See the [models](models.md) documentation for information on which optional dependencies are required for each model.
5859

docs/models.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ PydanticAI is Model-agnostic and has built in support for the following model pr
77
* [Deepseek](#deepseek)
88
* [Groq](#groq)
99
* [Mistral](#mistral)
10+
* [Cohere](#cohere)
1011

1112
See [OpenAI-compatible models](#openai-compatible-models) for more examples on how to use models such as [OpenRouter](#openrouter), and [Grok (xAI)](#grok-xai) that support the OpenAI SDK.
1213

@@ -419,6 +420,63 @@ agent = Agent(model)
419420
...
420421
```
421422

423+
## Cohere
424+
425+
### Install
426+
427+
To use [`CohereModel`][pydantic_ai.models.cohere.CohereModel], you need to either install [`pydantic-ai`](install.md), or install [`pydantic-ai-slim`](install.md#slim-install) with the `cohere` optional group:
428+
429+
```bash
430+
pip/uv-add 'pydantic-ai-slim[cohere]'
431+
```
432+
433+
### Configuration
434+
435+
To use [Cohere](https://cohere.com/) through their API, go to [dashboard.cohere.com/api-keys](https://dashboard.cohere.com/api-keys) and follow your nose until you find the place to generate an API key.
436+
437+
[`NamedCohereModels`][pydantic_ai.models.cohere.NamedCohereModels] contains a list of the most popular Cohere models.
438+
439+
### Environment variable
440+
441+
Once you have the API key, you can set it as an environment variable:
442+
443+
```bash
444+
export CO_API_KEY='your-api-key'
445+
```
446+
447+
You can then use [`CohereModel`][pydantic_ai.models.cohere.CohereModel] by name:
448+
449+
```python {title="cohere_model_by_name.py"}
450+
from pydantic_ai import Agent
451+
452+
agent = Agent('cohere:command')
453+
...
454+
```
455+
456+
Or initialise the model directly with just the model name:
457+
458+
```python {title="cohere_model_init.py"}
459+
from pydantic_ai import Agent
460+
from pydantic_ai.models.cohere import CohereModel
461+
462+
model = CohereModel('command', api_key='your-api-key')
463+
agent = Agent(model)
464+
...
465+
```
466+
467+
### `api_key` argument
468+
469+
If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.cohere.CohereModel.__init__]:
470+
471+
```python {title="cohere_model_api_key.py"}
472+
from pydantic_ai import Agent
473+
from pydantic_ai.models.cohere import CohereModel
474+
475+
model = CohereModel('command', api_key='your-api-key')
476+
agent = Agent(model)
477+
...
478+
```
479+
422480
## OpenAI-compatible Models
423481

424482
Many of the models are compatible with OpenAI API, and thus can be used with [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] in PydanticAI.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ nav:
5050
- api/models/base.md
5151
- api/models/openai.md
5252
- api/models/anthropic.md
53+
- api/models/cohere.md
5354
- api/models/gemini.md
5455
- api/models/vertexai.md
5556
- api/models/groq.md

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from collections.abc import Iterable
44
from dataclasses import dataclass, field
55
from itertools import chain
6-
from typing import Literal, TypeAlias, Union, cast
6+
from typing import Literal, Union, cast
77

88
from cohere import TextAssistantMessageContentItem
9+
from httpx import AsyncClient as AsyncHTTPClient
910
from typing_extensions import assert_never
1011

1112
from .. import result
@@ -51,24 +52,24 @@
5152
"you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
5253
) from _import_error
5354

54-
CohereModelName: TypeAlias = Union[
55-
str,
56-
Literal[
57-
'c4ai-aya-expanse-32b',
58-
'c4ai-aya-expanse-8b',
59-
'command',
60-
'command-light',
61-
'command-light-nightly',
62-
'command-nightly',
63-
'command-r',
64-
'command-r-03-2024',
65-
'command-r-08-2024',
66-
'command-r-plus',
67-
'command-r-plus-04-2024',
68-
'command-r-plus-08-2024',
69-
'command-r7b-12-2024',
70-
],
55+
NamedCohereModels = Literal[
56+
'c4ai-aya-expanse-32b',
57+
'c4ai-aya-expanse-8b',
58+
'command',
59+
'command-light',
60+
'command-light-nightly',
61+
'command-nightly',
62+
'command-r',
63+
'command-r-03-2024',
64+
'command-r-08-2024',
65+
'command-r-plus',
66+
'command-r-plus-04-2024',
67+
'command-r-plus-08-2024',
68+
'command-r7b-12-2024',
7169
]
70+
"""Latest / most popular named Cohere models."""
71+
72+
CohereModelName = Union[NamedCohereModels, str]
7273

7374

7475
class CohereModelSettings(ModelSettings):
@@ -96,23 +97,26 @@ def __init__(
9697
*,
9798
api_key: str | None = None,
9899
cohere_client: AsyncClientV2 | None = None,
100+
http_client: AsyncHTTPClient | None = None,
99101
):
100102
"""Initialize an Cohere model.
101103
102104
Args:
103105
model_name: The name of the Cohere model to use. List of model names
104106
available [here](https://docs.cohere.com/docs/models#command).
105107
api_key: The API key to use for authentication, if not provided, the
106-
`COHERE_API_KEY` environment variable will be used if available.
108+
`CO_API_KEY` environment variable will be used if available.
107109
cohere_client: An existing Cohere async client to use. If provided,
108-
`api_key` must be `None`.
110+
`api_key` and `http_client` must be `None`.
111+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
109112
"""
110113
self.model_name: CohereModelName = model_name
111114
if cohere_client is not None:
115+
assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
112116
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
113117
self.client = cohere_client
114118
else:
115-
self.client = AsyncClientV2(api_key=api_key) # type: ignore
119+
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore
116120

117121
async def agent_model(
118122
self,

tests/models/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
'MistralModel',
5757
),
5858
(
59-
'COHERE_API_KEY',
59+
'CO_API_KEY',
6060
'cohere:command',
6161
'cohere:command',
6262
'cohere',

tests/test_live.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ def mistral(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
6666
return MistralModel('mistral-small-latest', http_client=http_client)
6767

6868

69+
def cohere(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
70+
from pydantic_ai.models.cohere import CohereModel
71+
72+
return CohereModel('command-r7b-12-2024', http_client=http_client)
73+
74+
6975
params = [
7076
pytest.param(openai, id='openai'),
7177
pytest.param(gemini, marks=pytest.mark.skip(reason='API seems very flaky'), id='gemini'),
@@ -74,6 +80,7 @@ def mistral(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
7480
pytest.param(anthropic, id='anthropic'),
7581
pytest.param(ollama, id='ollama'),
7682
pytest.param(mistral, id='mistral'),
83+
pytest.param(cohere, id='cohere'),
7784
]
7885
GetModel = Callable[[httpx.AsyncClient, Path], Model]
7986

@@ -95,7 +102,7 @@ async def test_text(http_client: httpx.AsyncClient, tmp_path: Path, get_model: G
95102
assert usage.total_tokens is not None and usage.total_tokens > 0
96103

97104

98-
stream_params = [p for p in params if p.id != 'anthropic']
105+
stream_params = [p for p in params if p.id != 'cohere']
99106

100107

101108
@pytest.mark.parametrize('get_model', stream_params)

0 commit comments

Comments
 (0)