
Commit d273775: Add Cerebras provider (#2643)

1 parent 81f81df

33 files changed: +793, -80 lines

docs/api/providers.md (2 additions, 0 deletions)

```diff
@@ -16,6 +16,8 @@
 ::: pydantic_ai.providers.cohere
 
+::: pydantic_ai.providers.cerebras.CerebrasProvider
+
 ::: pydantic_ai.providers.mistral.MistralProvider
 
 ::: pydantic_ai.providers.fireworks.FireworksProvider
```

docs/models/index.md (1 addition, 0 deletions)

```diff
@@ -26,6 +26,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 - [Azure AI Foundry](openai.md#azure-ai-foundry)
 - [Heroku](openai.md#heroku-ai)
 - [GitHub Models](openai.md#github-models)
+- [Cerebras](openai.md#cerebras)
 
 Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md)
 for testing and development.
```

docs/models/openai.md (33 additions, 0 deletions)

Note: the model name in the examples is written `llama-3.3-70b` here to match the `cerebras:llama-3.3-70b` entry added to `KnownModelName` below.

````diff
@@ -530,3 +530,36 @@ You can set the `HEROKU_INFERENCE_KEY` and `HEROKU_INFERENCE_URL` environment va
 export HEROKU_INFERENCE_KEY='your-heroku-inference-key'
 export HEROKU_INFERENCE_URL='https://us.inference.heroku.com'
 ```
+
+### Cerebras
+
+To use [Cerebras](https://cerebras.ai/), you need to create an API key in the [Cerebras Console](https://cloud.cerebras.ai/).
+
+Once you've set the `CEREBRAS_API_KEY` environment variable, you can run the following:
+
+```python
+from pydantic_ai import Agent
+
+agent = Agent('cerebras:llama-3.3-70b')
+result = agent.run_sync('What is the capital of France?')
+print(result.output)
+#> The capital of France is Paris.
+```
+
+If you need to configure the provider, you can use the [`CerebrasProvider`][pydantic_ai.providers.cerebras.CerebrasProvider] class:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.cerebras import CerebrasProvider
+
+model = OpenAIChatModel(
+    'llama-3.3-70b',
+    provider=CerebrasProvider(api_key='your-cerebras-api-key'),
+)
+agent = Agent(model)
+
+result = agent.run_sync('What is the capital of France?')
+print(result.output)
+#> The capital of France is Paris.
+```
````

pydantic_ai_slim/pydantic_ai/models/__init__.py (18 additions, 8 deletions)

The Cerebras model names are registered in `KnownModelName`, and `'cerebras'` joins the (reordered) list of providers routed to `OpenAIChatModel` in `infer_model`:

```diff
@@ -111,6 +111,15 @@
     'bedrock:mistral.mixtral-8x7b-instruct-v0:1',
     'bedrock:mistral.mistral-large-2402-v1:0',
     'bedrock:mistral.mistral-large-2407-v1:0',
+    'cerebras:gpt-oss-120b',
+    'cerebras:llama3.1-8b',
+    'cerebras:llama-3.3-70b',
+    'cerebras:llama-4-scout-17b-16e-instruct',
+    'cerebras:llama-4-maverick-17b-128e-instruct',
+    'cerebras:qwen-3-235b-a22b-instruct-2507',
+    'cerebras:qwen-3-32b',
+    'cerebras:qwen-3-coder-480b',
+    'cerebras:qwen-3-235b-a22b-thinking-2507',
     'claude-3-5-haiku-20241022',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-20240620',
@@ -695,18 +704,19 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         return CohereModel(model_name, provider=provider)
     elif provider in (
-        'openai',
-        'openai-chat',
-        'deepseek',
         'azure',
-        'openrouter',
-        'vercel',
+        'deepseek',
+        'cerebras',
+        'fireworks',
+        'github',
         'grok',
+        'heroku',
         'moonshotai',
-        'fireworks',
+        'openai',
+        'openai-chat',
+        'openrouter',
         'together',
-        'heroku',
-        'github',
+        'vercel',
     ):
         from .openai import OpenAIChatModel
```
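With these registrations in place, string inference routes `cerebras:` model names through the OpenAI-compatible path. A minimal sketch (assuming this commit is installed with the `openai` extra and `CEREBRAS_API_KEY` is set, since `CerebrasProvider` requires a key at construction):

```python
from pydantic_ai.models import infer_model

# 'cerebras:...' names are now known model names and resolve to an
# OpenAIChatModel backed by CerebrasProvider.
model = infer_model('cerebras:llama-3.3-70b')
print(type(model).__name__)  # expected: OpenAIChatModel
```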

pydantic_ai_slim/pydantic_ai/models/openai.py (36 additions, 30 deletions)

The `provider` Literal in the `__init__` overloads gains `'cerebras'` and is reordered. The identical change is applied in three hunks, at `@@ -211,19 +211,20 @@`, `@@ -237,19 +238,20 @@`, and `@@ -262,19 +264,20 @@`; only the first is shown:

```diff
@@ -211,19 +211,20 @@ def __init__(
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'openai-chat',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
```
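With the new literal in place, the provider can be selected by name on the model itself. A hedged sketch (assuming `CEREBRAS_API_KEY` is set in the environment):

```python
from pydantic_ai.models.openai import OpenAIChatModel

# Equivalent to passing CerebrasProvider() explicitly; the string form
# delegates provider construction to provider inference, which reads
# CEREBRAS_API_KEY from the environment.
model = OpenAIChatModel('llama-3.3-70b', provider='cerebras')
```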
`_completions_create` now strips any settings the profile declares unsupported before building the request:

```diff
@@ -402,6 +405,11 @@ async def _completions_create(
         ):  # pragma: no branch
             response_format = {'type': 'json_object'}
 
+        unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings
+        for setting in unsupported_model_settings:
+            model_settings.pop(setting, None)
+
+        # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
         sampling_settings = (
             model_settings
             if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
```
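The stripping itself is just keyed removal from the settings mapping. A standalone plain-Python illustration (not the library's actual code path; the setting names are examples):

```python
# Settings a hypothetical profile declares as unsupported.
unsupported_model_settings = ('frequency_penalty', 'logit_bias')

# Settings the caller passed: unsupported keys are dropped, the rest survive.
model_settings = {'temperature': 0.2, 'frequency_penalty': 0.5}
for setting in unsupported_model_settings:
    model_settings.pop(setting, None)

print(model_settings)  # {'temperature': 0.2}
```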
A small cleanup in `_map_user_message` collapses a multi-line yield and drops its `# pragma: no cover`:

```diff
@@ -646,9 +654,7 @@ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.C
                 )
             elif isinstance(part, RetryPromptPart):
                 if part.tool_name is None:
-                    yield chat.ChatCompletionUserMessageParam(  # pragma: no cover
-                        role='user', content=part.model_response()
-                    )
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                 else:
                     yield chat.ChatCompletionToolMessageParam(
                         role='tool',
```
pydantic_ai_slim/pydantic_ai/profiles/harmony.py (new file, 13 additions; the filename was missing from the rendered page and is inferred from the `pydantic_ai.profiles.harmony` import in the Cerebras provider below)

```python
from __future__ import annotations as _annotations

from . import ModelProfile
from .openai import OpenAIModelProfile, openai_model_profile


def harmony_model_profile(model_name: str) -> ModelProfile | None:
    """The model profile for the OpenAI Harmony Response format.

    See <https://cookbook.openai.com/articles/openai-harmony> for more details.
    """
    profile = openai_model_profile(model_name)
    return OpenAIModelProfile(openai_supports_tool_choice_required=False).update(profile)
```
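In effect, `gpt-oss` models get the stock OpenAI profile with `tool_choice='required'` turned off. A quick check (a sketch assuming this commit, and that `ModelProfile.update` preserves this profile's explicitly set fields when merging):

```python
from pydantic_ai.profiles.harmony import harmony_model_profile

profile = harmony_model_profile('gpt-oss-120b')
# The harmony profile pins this flag off before merging in the base profile.
print(profile.openai_supports_tool_choice_required)  # expected: False
```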

pydantic_ai_slim/pydantic_ai/profiles/openai.py (5 additions, 0 deletions)

`OpenAIModelProfile` gains the `openai_unsupported_model_settings` field consumed by `_completions_create` above (the `openai_supports_sampling_settings` docstring is also smoothed here, as the original wording was ungrammatical):

```diff
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import re
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal
 
@@ -20,9 +21,13 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_strict_tool_definition: bool = True
     """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""
 
+    # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
     openai_supports_sampling_settings: bool = True
     """Turn this off to avoid sending sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
 
+    openai_unsupported_model_settings: Sequence[str] = ()
+    """A list of model settings that are not supported by the model."""
+
     # Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
     # `tool_choice="required"`. This flag lets the calling model know whether it's
     # safe to pass that value along. Default is `True` to preserve existing
```
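A hedged sketch of the new field in use; the setting names below are illustrative, not taken from any real provider profile:

```python
from pydantic_ai.profiles.openai import OpenAIModelProfile

# A profile for a hypothetical OpenAI-compatible API that rejects these
# settings; OpenAIChatModel pops them from model_settings before each request.
profile = OpenAIModelProfile(
    openai_unsupported_model_settings=('logit_bias', 'service_tier'),
)
print(profile.openai_unsupported_model_settings)  # ('logit_bias', 'service_tier')
```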

pydantic_ai_slim/pydantic_ai/profiles/qwen.py (8 additions, 0 deletions)

`qwen-3-coder` models get a dedicated branch with `tool_choice='required'` and strict tool definitions disabled:

```diff
@@ -1,10 +1,18 @@
 from __future__ import annotations as _annotations
 
+from ..profiles.openai import OpenAIModelProfile
 from . import InlineDefsJsonSchemaTransformer, ModelProfile
 
 
 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
+    if model_name.startswith('qwen-3-coder'):
+        return OpenAIModelProfile(
+            json_schema_transformer=InlineDefsJsonSchemaTransformer,
+            openai_supports_tool_choice_required=False,
+            openai_supports_strict_tool_definition=False,
+            ignore_streamed_leading_whitespace=True,
+        )
     return ModelProfile(
         json_schema_transformer=InlineDefsJsonSchemaTransformer,
         ignore_streamed_leading_whitespace=True,
```
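A quick check of the new branch (a minimal sketch assuming this commit):

```python
from pydantic_ai.profiles.openai import OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile

# qwen-3-coder models now get an OpenAI-specific profile with
# tool_choice='required' and strict tool definitions both disabled.
profile = qwen_model_profile('qwen-3-coder-480b')
assert isinstance(profile, OpenAIModelProfile)
print(profile.openai_supports_strict_tool_definition)  # False
```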

pydantic_ai_slim/pydantic_ai/providers/__init__.py (4 additions, 0 deletions)

```diff
@@ -95,6 +95,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .mistral import MistralProvider
 
         return MistralProvider
+    elif provider == 'cerebras':
+        from .cerebras import CerebrasProvider
+
+        return CerebrasProvider
     elif provider == 'cohere':
         from .cohere import CohereProvider
```
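String-based lookup now resolves `'cerebras'` without instantiating the provider, so no API key is needed for this check (sketch assuming this commit):

```python
from pydantic_ai.providers import infer_provider_class

provider_cls = infer_provider_class('cerebras')
print(provider_cls.__name__)  # CerebrasProvider
```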

pydantic_ai_slim/pydantic_ai/providers/cerebras.py (new file, 96 additions; the filename was missing from the rendered page and is inferred from the `from .cerebras import CerebrasProvider` import above)

```python
from __future__ import annotations as _annotations

import os
from typing import overload

import httpx

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.harmony import harmony_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Cerebras provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class CerebrasProvider(Provider[AsyncOpenAI]):
    """Provider for Cerebras API."""

    @property
    def name(self) -> str:
        return 'cerebras'

    @property
    def base_url(self) -> str:
        return 'https://api.cerebras.ai/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile}

        profile = None
        for prefix, profile_func in prefix_to_profile.items():
            model_name = model_name.lower()
            if model_name.startswith(prefix):
                profile = profile_func(model_name)

        # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features,
        # Cerebras doesn't support some model settings.
        unsupported_model_settings = (
            'frequency_penalty',
            'logit_bias',
            'presence_penalty',
            'parallel_tool_calls',
            'service_tier',
        )
        return OpenAIModelProfile(
            json_schema_transformer=OpenAIJsonSchemaTransformer,
            openai_unsupported_model_settings=unsupported_model_settings,
        ).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
                'to use the Cerebras provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='cerebras')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
```
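Construction mirrors the other OpenAI-compatible providers. A minimal usage sketch (assuming this commit; the key and timeout values are placeholders):

```python
import httpx

from pydantic_ai.providers.cerebras import CerebrasProvider

# Explicit key plus a custom HTTP client; omitting both falls back to
# CEREBRAS_API_KEY and a cached shared client.
provider = CerebrasProvider(
    api_key='your-cerebras-api-key',
    http_client=httpx.AsyncClient(timeout=30),
)
print(provider.base_url)  # https://api.cerebras.ai/v1
```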
