Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions docs/models/cerebras.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Cerebras

## Install

To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs `openai`):

```bash
pip install "pydantic-ai-slim[cerebras]"
```

or

```bash
uv add "pydantic-ai-slim[cerebras]"
```

## Configuration

To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc), create an account if you don't already have one, and generate an API key.

`CerebrasModelName` contains a list of available Cerebras models.

## Environment variable

Once you have the API key, you can set it as an environment variable:

```bash
export CEREBRAS_API_KEY='your-api-key'
```

You can then use `CerebrasModel` by name:

```python
from pydantic_ai import Agent

agent = Agent('cerebras:llama-3.3-70b')
...
```

Or initialise the model directly with just the model name:

```python
from pydantic_ai import Agent
from pydantic_ai.models.cerebras import CerebrasModel

model = CerebrasModel('llama-3.3-70b')
agent = Agent(model)
...
```

## `provider` argument

You can provide a custom `Provider` via the `provider` argument:

```python
from pydantic_ai import Agent
from pydantic_ai.models.cerebras import CerebrasModel
from pydantic_ai.providers.cerebras import CerebrasProvider

model = CerebrasModel(
'llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key')
)
agent = Agent(model)
...
```

You can also customize the `CerebrasProvider` with a custom `httpx.AsyncClient`:

```python
from httpx import AsyncClient

from pydantic_ai import Agent
from pydantic_ai.models.cerebras import CerebrasModel
from pydantic_ai.providers.cerebras import CerebrasProvider

custom_http_client = AsyncClient(timeout=30)
model = CerebrasModel(
'llama-3.3-70b',
provider=CerebrasProvider(api_key='your-api-key', http_client=custom_http_client),
)
agent = Agent(model)
...
```
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ nav:
- models/anthropic.md
- models/google.md
- models/bedrock.md
- models/cerebras.md
- models/cohere.md
- models/groq.md
- models/mistral.md
Expand Down
8 changes: 5 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@
'cerebras:llama-3.3-70b',
'cerebras:llama3.1-8b',
'cerebras:qwen-3-235b-a22b-instruct-2507',
'cerebras:qwen-3-235b-a22b-thinking-2507',
'cerebras:qwen-3-32b',
'cerebras:zai-glm-4.6',
'cohere:c4ai-aya-expanse-32b',
Expand Down Expand Up @@ -800,7 +799,6 @@ def infer_model( # noqa: C901
'openai',
'azure',
'deepseek',
'cerebras',
'fireworks',
'github',
'grok',
Expand All @@ -818,7 +816,11 @@ def infer_model( # noqa: C901
elif model_kind in ('google-gla', 'google-vertex'):
model_kind = 'google'

if model_kind == 'openai-chat':
if model_kind == 'cerebras':
from .cerebras import CerebrasModel

return CerebrasModel(model_name, provider=provider) # type: ignore[arg-type]
elif model_kind == 'openai-chat':
from .openai import OpenAIChatModel

return OpenAIChatModel(model_name, provider=provider)
Expand Down
58 changes: 58 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/cerebras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Cerebras model implementation using OpenAI-compatible API."""

from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Literal

from ..profiles import ModelProfileSpec
from ..providers import Provider
from ..settings import ModelSettings
from .openai import OpenAIChatModel

try:
from openai import AsyncOpenAI
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `openai` package to use the Cerebras model, '
'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"'
) from _import_error

__all__ = ('CerebrasModel', 'CerebrasModelName')

CerebrasModelName = Literal[
'gpt-oss-120b',
'llama-3.3-70b',
'llama3.1-8b',
'qwen-3-235b-a22b-instruct-2507',
'qwen-3-32b',
'zai-glm-4.6',
]


@dataclass(init=False)
class CerebrasModel(OpenAIChatModel):
    """A model running on Cerebras's OpenAI-compatible API.

    Cerebras serves open-weight models with very fast inference on its
    Wafer-Scale Engine (WSE) hardware.

    Other than `__init__`, every method is either private or inherited
    unchanged from `OpenAIChatModel`.
    """

    def __init__(
        self,
        model_name: CerebrasModelName,
        *,
        provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Initialize a Cerebras model.

        Args:
            model_name: Name of the Cerebras model to use.
            provider: Provider to use; defaults to the `'cerebras'` provider.
            profile: Model profile to use; defaults to one inferred from the model name.
            settings: Default model settings for this model instance.
        """
        # All behavior comes from the OpenAI-compatible base class.
        super().__init__(
            model_name,
            provider=provider,
            profile=profile,
            settings=settings,
        )
42 changes: 24 additions & 18 deletions pydantic_ai_slim/pydantic_ai/providers/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.harmony import harmony_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
from pydantic_ai.profiles.openai import OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

Expand All @@ -19,7 +19,7 @@
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `openai` package to use the Cerebras provider, '
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`'
) from _import_error


Expand All @@ -39,27 +39,22 @@ def client(self) -> AsyncOpenAI:
return self._client

def model_profile(self, model_name: str) -> ModelProfile | None:
prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile}
prefix_to_profile = {
'llama': meta_model_profile,
'qwen': qwen_model_profile,
'gpt-oss': harmony_model_profile,
}

profile = None
model_name_lower = model_name.lower()
for prefix, profile_func in prefix_to_profile.items():
model_name = model_name.lower()
if model_name.startswith(prefix):
if model_name_lower.startswith(prefix):
profile = profile_func(model_name)
break

# According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features,
# Cerebras doesn't support some model settings.
unsupported_model_settings = (
'frequency_penalty',
'logit_bias',
'presence_penalty',
'parallel_tool_calls',
'service_tier',
)
return OpenAIModelProfile(
json_schema_transformer=OpenAIJsonSchemaTransformer,
openai_unsupported_model_settings=unsupported_model_settings,
).update(profile)
# Wrap in OpenAIModelProfile with web search disabled
# Cerebras doesn't support web search
return OpenAIModelProfile(openai_chat_supports_web_search=False).update(profile)

@overload
def __init__(self) -> None: ...
Expand All @@ -70,6 +65,9 @@ def __init__(self, *, api_key: str) -> None: ...
@overload
def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

@overload
def __init__(self, *, http_client: httpx.AsyncClient) -> None: ...

@overload
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

Expand All @@ -80,6 +78,14 @@ def __init__(
openai_client: AsyncOpenAI | None = None,
http_client: httpx.AsyncClient | None = None,
) -> None:
"""Create a new Cerebras provider.

Args:
api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable
will be used if available.
openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`.
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
"""
api_key = api_key or os.getenv('CEREBRAS_API_KEY')
if not api_key and openai_client is None:
raise UserError(
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
google = ["google-genai>=1.51.0"]
anthropic = ["anthropic>=0.70.0"]
cerebras = ["openai>=1.107.2"]
groq = ["groq>=0.25.0"]
mistral = ["mistralai>=1.9.10"]
bedrock = ["boto3>=1.40.14"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ interactions:
alt-svc:
- h3=":443"; ma=86400
content-length:
- '570'
- '479'
content-type:
- application/json
referrer-policy:
Expand All @@ -133,32 +133,28 @@ interactions:
- max-age=3600; includeSubDomains
parsed_body:
data:
- created: 0
id: qwen-3-235b-a22b-thinking-2507
object: model
owned_by: Cerebras
- created: 0
id: llama-3.3-70b
object: model
owned_by: Cerebras
- created: 0
id: qwen-3-235b-a22b-instruct-2507
id: llama3.1-8b
object: model
owned_by: Cerebras
- created: 0
id: qwen-3-32b
id: zai-glm-4.6
object: model
owned_by: Cerebras
- created: 0
id: zai-glm-4.6
id: qwen-3-32b
object: model
owned_by: Cerebras
- created: 0
id: gpt-oss-120b
object: model
owned_by: Cerebras
- created: 0
id: llama3.1-8b
id: qwen-3-235b-a22b-instruct-2507
object: model
owned_by: Cerebras
object: list
Expand Down
73 changes: 73 additions & 0 deletions tests/models/test_cerebras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import re

import pytest

from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
from openai import AsyncOpenAI

from pydantic_ai.models.cerebras import CerebrasModel
from pydantic_ai.providers.cerebras import CerebrasProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_cerebras_model_init():
    """The model records its name and wires the provider's API key through."""
    provider = CerebrasProvider(api_key='test_key')
    model = CerebrasModel('llama-3.3-70b', provider=provider)

    assert model.model_name == 'llama-3.3-70b'
    assert isinstance(model._provider, CerebrasProvider)  # type: ignore[reportPrivateUsage]
    assert model._provider.client.api_key == 'test_key'  # type: ignore[reportPrivateUsage]


def test_cerebras_model_profile():
    """Every Cerebras profile is an `OpenAIModelProfile` with web search disabled."""
    provider = CerebrasProvider(api_key='test_key')

    # Expected JSON-schema transformer per model family; `None` means the name
    # matches no known prefix (zai-glm is valid but unrecognized), so no
    # transformer expectation is asserted for it.
    expected_transformers = {
        'llama-3.3-70b': InlineDefsJsonSchemaTransformer,
        'qwen-3-235b-a22b-instruct-2507': InlineDefsJsonSchemaTransformer,
        'gpt-oss-120b': OpenAIJsonSchemaTransformer,
        'zai-glm-4.6': None,
    }

    for model_name, transformer in expected_transformers.items():
        profile = CerebrasModel(model_name, provider=provider).profile
        assert isinstance(profile, OpenAIModelProfile)
        if transformer is not None:
            assert profile.json_schema_transformer == transformer
        assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False


def test_cerebras_provider_need_api_key(env: TestEnv) -> None:
    """With no `CEREBRAS_API_KEY` set and no explicit key, construction raises `UserError`."""
    env.remove('CEREBRAS_API_KEY')
    expected_message = re.escape(
        'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
        'to use the Cerebras provider.'
    )
    with pytest.raises(UserError, match=expected_message):
        CerebrasProvider()


def test_cerebras_provider_pass_openai_client() -> None:
    """An explicitly supplied `AsyncOpenAI` client is used as-is by the provider."""
    client = AsyncOpenAI(api_key='test_key', base_url='https://api.cerebras.ai/v1')
    assert CerebrasProvider(openai_client=client).client == client
Loading