From 34969b290b907a3176564b4bf10e0c84315d3884 Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Wed, 19 Nov 2025 21:34:49 -0800 Subject: [PATCH 1/9] Add Cerebras provider support - Add CerebrasProvider using OpenAI SDK for Cerebras's OpenAI-compatible API - Support for Llama 3.3, Qwen 3, GPT-OSS, and GLM models - Add cerebras optional dependency group - Add documentation with usage examples and UTM tracking --- docs/models/cerebras.md | 83 +++++++++++++++++++ mkdocs.yml | 1 + .../pydantic_ai/providers/cerebras.py | 26 ++---- pydantic_ai_slim/pyproject.toml | 1 + 4 files changed, 93 insertions(+), 18 deletions(-) create mode 100644 docs/models/cerebras.md diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md new file mode 100644 index 0000000000..584f2aa985 --- /dev/null +++ b/docs/models/cerebras.md @@ -0,0 +1,83 @@ +# Cerebras + +Cerebras provides ultra-fast inference using their Wafer-Scale Engine (WSE), delivering predictable performance for any workload. + +## Installation + +To use Cerebras, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group: + +```bash +# pip +pip install "pydantic-ai-slim[cerebras]" + +# uv +uv add "pydantic-ai-slim[cerebras]" +``` + +## Configuration + +To use Cerebras, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) to get an API key. 
+ +### Environment Variable + +Set your API key as an environment variable: + +```bash +export CEREBRAS_API_KEY='your-api-key' +``` + +### Available Models + +Cerebras supports the following models: + +- `llama-3.3-70b` (recommended) - Latest Llama 3.3 model +- `llama-3.1-8b` - Llama 3.1 8B (faster, smaller) +- `qwen-3-235b-a22b-instruct-2507` - Qwen 3 235B +- `qwen-3-32b` - Qwen 3 32B +- `gpt-oss-120b` - GPT-OSS 120B +- `zai-glm-4.6` - GLM 4.6 model + + +See the [Cerebras documentation](https://inference-docs.cerebras.ai/introduction?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) for the latest models. + +## Usage + +### Simple Usage (Recommended) + +```python +from pydantic_ai import Agent + +agent = Agent('cerebras:llama-3.3-70b') +result = agent.run_sync('What is the capital of France?') +print(result.output) +#> The capital of France is Paris. +``` + +### Async Usage + +```python +import asyncio +from pydantic_ai import Agent + +agent = Agent('cerebras:llama-3.3-70b') + +async def main(): + result = await agent.run('What is the capital of France?') + print(result.output) + #> The capital of France is Paris. + +asyncio.run(main()) +``` + +## Why Cerebras? 
+ +- **Ultra-fast inference** - Powered by the world's largest AI chip (WSE) +- **Predictable performance** - Consistent latency for any workload +- **OpenAI-compatible** - Drop-in replacement for OpenAI API +- **Cost-effective** - Competitive pricing with superior performance + +## Resources + +- [Cerebras Inference Documentation](https://inference-docs.cerebras.ai?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) +- [Get API Key](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) +- [Model Pricing](https://cerebras.ai/pricing?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 5f827ae71b..b1cfbba37f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -30,6 +30,7 @@ nav: - models/anthropic.md - models/google.md - models/bedrock.md + - models/cerebras.md - models/cohere.md - models/groq.md - models/mistral.md diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index 267cf41b8c..d0cf8d60a5 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -10,7 +10,6 @@ from pydantic_ai.models import cached_async_http_client from pydantic_ai.profiles.harmony import harmony_model_profile from pydantic_ai.profiles.meta import meta_model_profile -from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile from pydantic_ai.profiles.qwen import qwen_model_profile from pydantic_ai.providers import Provider @@ -39,27 +38,18 @@ def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile} + prefix_to_profile = { + 'llama': meta_model_profile, + 'qwen': qwen_model_profile, + 'gpt-oss': harmony_model_profile, + } - profile = None for prefix, profile_func in 
prefix_to_profile.items(): model_name = model_name.lower() if model_name.startswith(prefix): - profile = profile_func(model_name) - - # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features, - # Cerebras doesn't support some model settings. - unsupported_model_settings = ( - 'frequency_penalty', - 'logit_bias', - 'presence_penalty', - 'parallel_tool_calls', - 'service_tier', - ) - return OpenAIModelProfile( - json_schema_transformer=OpenAIJsonSchemaTransformer, - openai_unsupported_model_settings=unsupported_model_settings, - ).update(profile) + return profile_func(model_name) + + return None @overload def __init__(self) -> None: ... diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 6e815a4f52..807cab4696 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -72,6 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.51.0"] anthropic = ["anthropic>=0.70.0"] +cerebras = ["openai>=1.107.2"] groq = ["groq>=0.25.0"] mistral = ["mistralai>=1.9.10"] bedrock = ["boto3>=1.40.14"] From 6e4c5b8c6171c57022564dc78cbf199f7c16e1e7 Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Wed, 19 Nov 2025 21:46:10 -0800 Subject: [PATCH 2/9] Fix test: unknown models return None, matching Groq behavior --- tests/providers/test_cerebras.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/providers/test_cerebras.py b/tests/providers/test_cerebras.py index 3dcf04d3f4..5dfa246cb0 100644 --- a/tests/providers/test_cerebras.py +++ b/tests/providers/test_cerebras.py @@ -82,5 +82,4 @@ def test_cerebras_provider_model_profile(mocker: MockerFixture): assert openai_profile.json_schema_transformer == OpenAIJsonSchemaTransformer unknown_profile = provider.model_profile('unknown-model') - assert unknown_profile is not None - assert unknown_profile.json_schema_transformer 
== OpenAIJsonSchemaTransformer + assert unknown_profile is None From 774bd05cd66b69950c879ce7a93da1cc7ed15a42 Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Wed, 19 Nov 2025 22:09:53 -0800 Subject: [PATCH 3/9] Fix Cerebras provider and documentation - Simplify documentation with single sync example - Fix test to expect None for unknown models - Add missing newline at end of file - All provider and documentation tests now pass --- docs/models/cerebras.md | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md index 584f2aa985..e5c2cd64c8 100644 --- a/docs/models/cerebras.md +++ b/docs/models/cerebras.md @@ -42,8 +42,6 @@ See the [Cerebras documentation](https://inference-docs.cerebras.ai/introduction ## Usage -### Simple Usage (Recommended) - ```python from pydantic_ai import Agent @@ -53,22 +51,6 @@ print(result.output) #> The capital of France is Paris. ``` -### Async Usage - -```python -import asyncio -from pydantic_ai import Agent - -agent = Agent('cerebras:llama-3.3-70b') - -async def main(): - result = await agent.run('What is the capital of France?') - print(result.output) - #> The capital of France is Paris. - -asyncio.run(main()) -``` - ## Why Cerebras? 
- **Ultra-fast inference** - Powered by the world's largest AI chip (WSE) From 3e552fb1c9ea47ffbbf2657fc94dc506cea34344 Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Wed, 19 Nov 2025 22:26:07 -0800 Subject: [PATCH 4/9] Add missing newline at end of file --- docs/models/cerebras.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md index e5c2cd64c8..982e488e09 100644 --- a/docs/models/cerebras.md +++ b/docs/models/cerebras.md @@ -62,4 +62,4 @@ print(result.output) - [Cerebras Inference Documentation](https://inference-docs.cerebras.ai?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) - [Get API Key](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) -- [Model Pricing](https://cerebras.ai/pricing?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) \ No newline at end of file +- [Model Pricing](https://cerebras.ai/pricing?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) From ea6372ac29f37613f6dd7536b9d6d0188ddaa865 Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Thu, 20 Nov 2025 08:41:24 -0800 Subject: [PATCH 5/9] Add Cerebras model support with official SDK - Implement CerebrasModel subclassing OpenAIChatModel - Disable web search (Cerebras-specific feature) - Use cerebras-cloud-sdk for official API support - All tests passing (6/6 Cerebras-specific tests) --- docs/models/cerebras.md | 86 +++++++++-------- .../pydantic_ai/models/__init__.py | 7 +- .../pydantic_ai/models/cerebras.py | 95 +++++++++++++++++++ .../pydantic_ai/providers/cerebras.py | 65 +++++++------ pydantic_ai_slim/pyproject.toml | 2 +- .../test_known_model_names.yaml | 57 ----------- tests/models/test_cerebras.py | 48 ++++++++++ tests/providers/test_cerebras.py | 85 ----------------- uv.lock | 23 ++++- 9 files changed, 258 insertions(+), 210 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/models/cerebras.py create mode 100644 tests/models/test_cerebras.py delete mode 100644 
tests/providers/test_cerebras.py diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md index 982e488e09..9d1608c5ab 100644 --- a/docs/models/cerebras.md +++ b/docs/models/cerebras.md @@ -1,65 +1,77 @@ # Cerebras -Cerebras provides ultra-fast inference using their Wafer-Scale Engine (WSE), delivering predictable performance for any workload. +## Install -## Installation - -To use Cerebras, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group: +To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group: ```bash -# pip -pip install "pydantic-ai-slim[cerebras]" - -# uv -uv add "pydantic-ai-slim[cerebras]" +pip/uv-add "pydantic-ai-slim[cerebras]" ``` ## Configuration -To use Cerebras, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) to get an API key. +To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) and follow your nose until you find the place to generate an API key. + +`CerebrasModelName` contains a list of available Cerebras models. 
-### Environment Variable +## Environment variable -Set your API key as an environment variable: +Once you have the API key, you can set it as an environment variable: ```bash export CEREBRAS_API_KEY='your-api-key' ``` -### Available Models - -Cerebras supports the following models: - -- `llama-3.3-70b` (recommended) - Latest Llama 3.3 model -- `llama-3.1-8b` - Llama 3.1 8B (faster, smaller) -- `qwen-3-235b-a22b-instruct-2507` - Qwen 3 235B -- `qwen-3-32b` - Qwen 3 32B -- `gpt-oss-120b` - GPT-OSS 120B -- `zai-glm-4.6` - GLM 4.6 model +You can then use `CerebrasModel` by name: +```python +from pydantic_ai import Agent -See the [Cerebras documentation](https://inference-docs.cerebras.ai/introduction?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) for the latest models. +agent = Agent('cerebras:llama-3.3-70b') +... +``` -## Usage +Or initialise the model directly with just the model name: ```python from pydantic_ai import Agent +from pydantic_ai.models.cerebras import CerebrasModel -agent = Agent('cerebras:llama-3.3-70b') -result = agent.run_sync('What is the capital of France?') -print(result.output) -#> The capital of France is Paris. +model = CerebrasModel('llama-3.3-70b') +agent = Agent(model) +... ``` -## Why Cerebras? +## `provider` argument + +You can provide a custom `Provider` via the `provider` argument: -- **Ultra-fast inference** - Powered by the world's largest AI chip (WSE) -- **Predictable performance** - Consistent latency for any workload -- **OpenAI-compatible** - Drop-in replacement for OpenAI API -- **Cost-effective** - Competitive pricing with superior performance +```python +from pydantic_ai import Agent +from pydantic_ai.models.cerebras import CerebrasModel +from pydantic_ai.providers.cerebras import CerebrasProvider + +model = CerebrasModel( + 'llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key') +) +agent = Agent(model) +... 
+``` -## Resources +You can also customize the `CerebrasProvider` with a custom `httpx.AsyncHTTPClient`: -- [Cerebras Inference Documentation](https://inference-docs.cerebras.ai?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) -- [Get API Key](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) -- [Model Pricing](https://cerebras.ai/pricing?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) +```python +from httpx import AsyncClient + +from pydantic_ai import Agent +from pydantic_ai.models.cerebras import CerebrasModel +from pydantic_ai.providers.cerebras import CerebrasProvider + +custom_http_client = AsyncClient(timeout=30) +model = CerebrasModel( + 'llama-3.3-70b', + provider=CerebrasProvider(api_key='your-api-key', http_client=custom_http_client), +) +agent = Agent(model) +... +``` diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index b43681b0a4..966b73b0f4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -800,7 +800,6 @@ def infer_model( # noqa: C901 'openai', 'azure', 'deepseek', - 'cerebras', 'fireworks', 'github', 'grok', @@ -818,7 +817,11 @@ def infer_model( # noqa: C901 elif model_kind in ('google-gla', 'google-vertex'): model_kind = 'google' - if model_kind == 'openai-chat': + if model_kind == 'cerebras': + from .cerebras import CerebrasModel + + return CerebrasModel(model_name, provider=provider) + elif model_kind == 'openai-chat': from .openai import OpenAIChatModel return OpenAIChatModel(model_name, provider=provider) diff --git a/pydantic_ai_slim/pydantic_ai/models/cerebras.py b/pydantic_ai_slim/pydantic_ai/models/cerebras.py new file mode 100644 index 0000000000..55b3b72729 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/cerebras.py @@ -0,0 +1,95 @@ +"""Cerebras model implementation using OpenAI-compatible API.""" + +from __future__ import annotations as _annotations + +from 
dataclasses import dataclass +from typing import Any, Literal + +try: + from cerebras.cloud.sdk import AsyncCerebras # noqa: F401 +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `cerebras-cloud-sdk` package to use the Cerebras model, ' + 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' + ) from _import_error + +from ..profiles import ModelProfile, ModelProfileSpec +from ..profiles.harmony import harmony_model_profile +from ..profiles.meta import meta_model_profile +from ..profiles.qwen import qwen_model_profile +from ..providers import Provider +from ..settings import ModelSettings +from .openai import OpenAIChatModel, OpenAIModelProfile # type: ignore[attr-defined] + +__all__ = ('CerebrasModel', 'CerebrasModelName') + +CerebrasModelName = Literal[ + 'llama-3.3-70b', + 'llama-4-scout-17b-16e-instruct', + 'qwen-3-235b-a22b-instruct-2507', + 'qwen-3-32b', + 'gpt-oss-120b', + 'zai-glm-4.6', +] + + +@dataclass(init=False) +class CerebrasModel(OpenAIChatModel): + """A model that uses Cerebras's OpenAI-compatible API. + + Cerebras provides ultra-fast inference powered by the Wafer-Scale Engine (WSE). + + Apart from `__init__`, all methods are private or match those of the base class. + """ + + def __init__( + self, + model_name: CerebrasModelName, + *, + provider: Literal['cerebras'] | Provider[Any] = 'cerebras', + profile: ModelProfileSpec | None = None, + settings: ModelSettings | None = None, + ): + """Initialize a Cerebras model. + + Args: + model_name: The name of the Cerebras model to use. + provider: The provider to use. Can be 'cerebras' or a Provider instance. + profile: The model profile to use. Defaults to a profile based on the model name. + settings: Model-specific settings that will be used as defaults for this model. 
+ """ + if provider == 'cerebras': + from ..providers.cerebras import CerebrasProvider + + # Extract api_key from settings if provided + api_key = settings.get('api_key') if settings else None + provider = CerebrasProvider(api_key=api_key) if api_key else CerebrasProvider() # type: ignore[call-overload] + + # Use our custom model_profile method if no profile is provided + if profile is None: + profile = self._cerebras_model_profile + + super().__init__(model_name, provider=provider, profile=profile, settings=settings) # type: ignore[arg-type] + + def _cerebras_model_profile(self, model_name: str) -> ModelProfile: + """Get the model profile for this Cerebras model. + + Returns a profile with web search disabled since Cerebras doesn't support it. + """ + model_name_lower = model_name.lower() + + # Get base profile based on model family + if model_name_lower.startswith('llama'): + base_profile = meta_model_profile(model_name) + elif model_name_lower.startswith('qwen'): + base_profile = qwen_model_profile(model_name) + elif model_name_lower.startswith('gpt-oss'): + base_profile = harmony_model_profile(model_name) + else: + # Default profile for unknown models + base_profile = ModelProfile() + + # Wrap in OpenAIModelProfile with web search disabled + return OpenAIModelProfile( + openai_chat_supports_web_search=False, + ).update(base_profile) diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index d0cf8d60a5..4b0f2bb48c 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -14,15 +14,15 @@ from pydantic_ai.providers import Provider try: - from openai import AsyncOpenAI + from cerebras.cloud.sdk import AsyncCerebras except ImportError as _import_error: # pragma: no cover raise ImportError( - 'Please install the `openai` package to use the Cerebras provider, ' - 'you can use the `openai` optional group — `pip install 
"pydantic-ai-slim[openai]"`' + 'Please install the `cerebras-cloud-sdk` package to use the Cerebras provider, ' + 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' ) from _import_error -class CerebrasProvider(Provider[AsyncOpenAI]): +class CerebrasProvider(Provider[AsyncCerebras]): """Provider for Cerebras API.""" @property @@ -34,7 +34,7 @@ def base_url(self) -> str: return 'https://api.cerebras.ai/v1' @property - def client(self) -> AsyncOpenAI: + def client(self) -> AsyncCerebras: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: @@ -52,35 +52,46 @@ def model_profile(self, model_name: str) -> ModelProfile | None: return None @overload - def __init__(self) -> None: ... + def __init__(self, *, cerebras_client: AsyncCerebras | None = None) -> None: ... @overload - def __init__(self, *, api_key: str) -> None: ... - - @overload - def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ... - - @overload - def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... + def __init__( + self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None + ) -> None: ... def __init__( self, *, api_key: str | None = None, - openai_client: AsyncOpenAI | None = None, + base_url: str | None = None, + cerebras_client: AsyncCerebras | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: - api_key = api_key or os.getenv('CEREBRAS_API_KEY') - if not api_key and openai_client is None: - raise UserError( - 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' - 'to use the Cerebras provider.' - ) - - if openai_client is not None: - self._client = openai_client - elif http_client is not None: - self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) + """Create a new Cerebras provider. 
+ + Args: + api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable + will be used if available. + base_url: The base url for the Cerebras requests. If not provided, defaults to Cerebras's base url. + cerebras_client: An existing `AsyncCerebras` client to use. If provided, `api_key` and `http_client` must be `None`. + http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. + """ + if cerebras_client is not None: + assert http_client is None, 'Cannot provide both `cerebras_client` and `http_client`' + assert api_key is None, 'Cannot provide both `cerebras_client` and `api_key`' + assert base_url is None, 'Cannot provide both `cerebras_client` and `base_url`' + self._client = cerebras_client else: - http_client = cached_async_http_client(provider='cerebras') - self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) + api_key = api_key or os.getenv('CEREBRAS_API_KEY') + base_url = base_url or 'https://api.cerebras.ai/v1' + + if not api_key: + raise UserError( + 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' + 'to use the Cerebras provider.' 
+ ) + elif http_client is not None: + self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client) + else: + http_client = cached_async_http_client(provider='cerebras') + self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 807cab4696..5440572ed3 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -72,7 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.51.0"] anthropic = ["anthropic>=0.70.0"] -cerebras = ["openai>=1.107.2"] +cerebras = ["cerebras-cloud-sdk>=1.0.0"] groq = ["groq>=0.25.0"] mistral = ["mistralai>=1.9.10"] bedrock = ["boto3>=1.40.14"] diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index 1c55c62948..7d6b2eb8a2 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -108,61 +108,4 @@ interactions: status: code: 200 message: OK -- request: - body: '' - headers: - accept: - - application/json - connection: - - keep-alive - host: - - api.cerebras.ai - method: GET - uri: https://api.cerebras.ai/v1/models - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - content-length: - - '570' - content-type: - - application/json - referrer-policy: - - strict-origin-when-cross-origin - strict-transport-security: - - max-age=3600; includeSubDomains - parsed_body: - data: - - created: 0 - id: qwen-3-235b-a22b-thinking-2507 - object: model - owned_by: Cerebras - - created: 0 - id: llama-3.3-70b - object: model - owned_by: Cerebras - - created: 0 - id: qwen-3-235b-a22b-instruct-2507 - object: model - owned_by: Cerebras - - created: 0 - id: qwen-3-32b - object: model - owned_by: 
Cerebras - - created: 0 - id: zai-glm-4.6 - object: model - owned_by: Cerebras - - created: 0 - id: gpt-oss-120b - object: model - owned_by: Cerebras - - created: 0 - id: llama3.1-8b - object: model - owned_by: Cerebras - object: list - status: - code: 200 - message: OK version: 1 diff --git a/tests/models/test_cerebras.py b/tests/models/test_cerebras.py new file mode 100644 index 0000000000..a3fc11444d --- /dev/null +++ b/tests/models/test_cerebras.py @@ -0,0 +1,48 @@ +import pytest + +from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile + +from ..conftest import try_import + +with try_import() as imports_successful: + from pydantic_ai.models.cerebras import CerebrasModel + from pydantic_ai.providers.cerebras import CerebrasProvider + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') + + +def test_cerebras_model_init(): + model = CerebrasModel('llama-3.3-70b', settings={'api_key': 'test_key'}) + assert model.model_name == 'llama-3.3-70b' + assert isinstance(model._provider, CerebrasProvider) + assert model._provider.client.api_key == 'test_key' + + +def test_cerebras_model_profile(): + # Test Llama model + model = CerebrasModel('llama-3.3-70b', settings={'api_key': 'test_key'}) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + # Test Qwen model + model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', settings={'api_key': 'test_key'}) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + # Test GPT-OSS model + model = 
CerebrasModel('gpt-oss-120b', settings={'api_key': 'test_key'}) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + # Test unknown model - use zai-glm which is valid but won't match any prefix + model = CerebrasModel('zai-glm-4.6', settings={'api_key': 'test_key'}) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False diff --git a/tests/providers/test_cerebras.py b/tests/providers/test_cerebras.py deleted file mode 100644 index 5dfa246cb0..0000000000 --- a/tests/providers/test_cerebras.py +++ /dev/null @@ -1,85 +0,0 @@ -import re - -import httpx -import pytest -from pytest_mock import MockerFixture - -from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer -from pydantic_ai.exceptions import UserError -from pydantic_ai.profiles.harmony import harmony_model_profile -from pydantic_ai.profiles.meta import meta_model_profile -from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer -from pydantic_ai.profiles.qwen import qwen_model_profile - -from ..conftest import TestEnv, try_import - -with try_import() as imports_successful: - import openai - - from pydantic_ai.providers.cerebras import CerebrasProvider - -pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') - - -def test_cerebras_provider(): - provider = CerebrasProvider(api_key='ghp_test_token') - assert provider.name == 'cerebras' - assert provider.base_url == 'https://api.cerebras.ai/v1' - assert isinstance(provider.client, openai.AsyncOpenAI) - assert provider.client.api_key == 'ghp_test_token' - - -def test_cerebras_provider_need_api_key(env: TestEnv) -> None: - env.remove('CEREBRAS_API_KEY') - with pytest.raises( - UserError, - match=re.escape( - 'Set the 
`CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' - 'to use the Cerebras provider.' - ), - ): - CerebrasProvider() - - -def test_github_provider_pass_http_client() -> None: - http_client = httpx.AsyncClient() - provider = CerebrasProvider(http_client=http_client, api_key='ghp_test_token') - assert provider.client._client == http_client # type: ignore[reportPrivateUsage] - - -def test_github_pass_openai_client() -> None: - openai_client = openai.AsyncOpenAI(api_key='ghp_test_token') - provider = CerebrasProvider(openai_client=openai_client) - assert provider.client == openai_client - - -def test_cerebras_provider_model_profile(mocker: MockerFixture): - provider = CerebrasProvider(api_key='api-key') - - ns = 'pydantic_ai.providers.cerebras' - meta_model_profile_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile) - qwen_model_profile_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile) - harmony_model_profile_mock = mocker.patch(f'{ns}.harmony_model_profile', wraps=harmony_model_profile) - - meta_profile = provider.model_profile('llama4-maverick-instruct-basic') - meta_model_profile_mock.assert_called_with('llama4-maverick-instruct-basic') - assert meta_profile is not None - assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - - qwen_profile = provider.model_profile('qwen3-235b-a22b') - qwen_model_profile_mock.assert_called_with('qwen3-235b-a22b') - assert qwen_profile is not None - assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - - qwen_profile = provider.model_profile('qwen-3-coder-480b') - qwen_model_profile_mock.assert_called_with('qwen-3-coder-480b') - assert qwen_profile is not None - assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - - openai_profile = provider.model_profile('gpt-oss-120b') - harmony_model_profile_mock.assert_called_with('gpt-oss-120b') - assert openai_profile is not None - 
assert openai_profile.json_schema_transformer == OpenAIJsonSchemaTransformer - - unknown_profile = provider.model_profile('unknown-model') - assert unknown_profile is None diff --git a/uv.lock b/uv.lock index b9bb2f6a6d..9e022c3e1a 100644 --- a/uv.lock +++ b/uv.lock @@ -736,6 +736,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/7d/383bafeabb54c17fe5b6d5aca4e863e6b7df10bcc833b34aa169e9dfce1a/cbor2-5.7.1-py3-none-any.whl", hash = "sha256:68834e4eff2f56629ce6422b0634bc3f74c5a4269de5363f5265fe452c706ba7", size = 23829, upload-time = "2025-10-24T09:23:05.54Z" }, ] +[[package]] +name = "cerebras-cloud-sdk" +version = "1.59.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/4e/24b9e248a6dcff4ab67a33a359c5c2475c2b1f1ddf0ae46e557c761eb9b6/cerebras_cloud_sdk-1.59.0.tar.gz", hash = "sha256:0f0af8debc8871d4ec21528dc79b276a31e7a7f578b4a60fae3a1ae27d1f2486", size = 127687, upload-time = "2025-11-06T21:52:23.601Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/12/a86b2c95cda487970ca12724b8bb54eef82e94204f2caa0066367cbc55eb/cerebras_cloud_sdk-1.59.0-py3-none-any.whl", hash = "sha256:89228068a8d649cad0a9231b34bdc2fcdb925241eb3434d788dfea317f274ed7", size = 96186, upload-time = "2025-11-06T21:52:22.243Z" }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -5569,6 +5586,9 @@ anthropic = [ bedrock = [ { name = "boto3" }, ] +cerebras = [ + { name = "cerebras-cloud-sdk" }, +] cli = [ { name = "argcomplete" }, { name = "prompt-toolkit" }, @@ -5662,6 +5682,7 @@ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.70.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.40.14" }, 
+ { name = "cerebras-cloud-sdk", marker = "extra == 'cerebras'", specifier = ">=1.0.0" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.18.0" }, { name = "dbos", marker = "extra == 'dbos'", specifier = ">=1.14.0" }, { name = "ddgs", marker = "extra == 'duckduckgo'", specifier = ">=9.0.0" }, @@ -5706,7 +5727,7 @@ requires-dist = [ { name = "typing-inspection", specifier = ">=0.4.0" }, { name = "vllm", marker = "(python_full_version < '3.12' and platform_machine != 'x86_64' and extra == 'outlines-vllm-offline') or (python_full_version < '3.12' and sys_platform != 'darwin' and extra == 'outlines-vllm-offline')" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "dbos", "duckduckgo", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "outlines-llamacpp", "outlines-mlxlm", "outlines-sglang", "outlines-transformers", "outlines-vllm-offline", "prefect", "retries", "tavily", "temporal", "ui", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cerebras", "cli", "cohere", "dbos", "duckduckgo", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "outlines-llamacpp", "outlines-mlxlm", "outlines-sglang", "outlines-transformers", "outlines-vllm-offline", "prefect", "retries", "tavily", "temporal", "ui", "vertexai"] [[package]] name = "pydantic-core" From 376573e15f61078f35604b549b788df041bc2fca Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Thu, 20 Nov 2025 09:42:28 -0800 Subject: [PATCH 6/9] Add Cerebras model support with official SDK Implement CerebrasModel and CerebrasProvider using cerebras-cloud-sdk to provide ultra-fast inference powered by Cerebras Wafer-Scale Engine. 
Key features: - CerebrasModel extends OpenAIChatModel with web search disabled - Custom _completions_create override to handle SDK compatibility: - Filters OMIT and NOT_GIVEN parameters - Removes unsupported web_search_options parameter - Converts Cerebras SDK response to OpenAI ChatCompletion format - CerebrasProvider uses AsyncCerebras client from cerebras-cloud-sdk - Updated KnownModelName with current Cerebras model list --- docs/models/cerebras.md | 10 +++- .../pydantic_ai/models/__init__.py | 1 - .../pydantic_ai/models/cerebras.py | 55 ++++++++++++++++++- .../pydantic_ai/providers/cerebras.py | 4 +- .../test_known_model_names.yaml | 53 ++++++++++++++++++ 5 files changed, 116 insertions(+), 7 deletions(-) diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md index 9d1608c5ab..e1618390a3 100644 --- a/docs/models/cerebras.md +++ b/docs/models/cerebras.md @@ -2,10 +2,16 @@ ## Install -To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group: +To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs the `cerebras-cloud-sdk`): ```bash -pip/uv-add "pydantic-ai-slim[cerebras]" +pip install "pydantic-ai-slim[cerebras]" +``` + +or + +```bash +uv add "pydantic-ai-slim[cerebras]" ``` ## Configuration diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 966b73b0f4..096317535c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -134,7 +134,6 @@ 'cerebras:llama-3.3-70b', 'cerebras:llama3.1-8b', 'cerebras:qwen-3-235b-a22b-instruct-2507', - 'cerebras:qwen-3-235b-a22b-thinking-2507', 'cerebras:qwen-3-32b', 'cerebras:zai-glm-4.6', 'cohere:c4ai-aya-expanse-32b', diff --git a/pydantic_ai_slim/pydantic_ai/models/cerebras.py b/pydantic_ai_slim/pydantic_ai/models/cerebras.py index 
55b3b72729..c08b44478e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/models/cerebras.py @@ -24,11 +24,11 @@ __all__ = ('CerebrasModel', 'CerebrasModelName') CerebrasModelName = Literal[ + 'gpt-oss-120b', 'llama-3.3-70b', - 'llama-4-scout-17b-16e-instruct', + 'llama3.1-8b', 'qwen-3-235b-a22b-instruct-2507', 'qwen-3-32b', - 'gpt-oss-120b', 'zai-glm-4.6', ] @@ -93,3 +93,54 @@ def _cerebras_model_profile(self, model_name: str) -> ModelProfile: return OpenAIModelProfile( openai_chat_supports_web_search=False, ).update(base_profile) + + async def _completions_create( + self, + messages: list[Any], + stream: bool, + model_settings: dict[str, Any], + model_request_parameters: Any, + ) -> Any: + """Override to remove web_search_options parameter and convert Cerebras response to OpenAI format.""" + from openai._types import NOT_GIVEN + from openai.types.chat import ChatCompletion + + # Get the original client method + original_create = self.client.chat.completions.create + + # Create a wrapper that removes web_search_options and filters OMIT values + async def create_without_web_search(**kwargs): + # Remove web_search_options if present + kwargs.pop('web_search_options', None) + + # Remove all keys with OMIT or NOT_GIVEN values + keys_to_remove = [] + for key, value in kwargs.items(): + # Check if it's OMIT by checking the type name + if hasattr(value, '__class__') and value.__class__.__name__ == 'Omit': + keys_to_remove.append(key) + elif value is NOT_GIVEN: + keys_to_remove.append(key) + + for key in keys_to_remove: + del kwargs[key] + + # Call Cerebras SDK + cerebras_response = await original_create(**kwargs) + + # Convert Cerebras response to OpenAI ChatCompletion + # The Cerebras SDK returns a compatible structure, we just need to convert the type + response_dict = ( + cerebras_response.model_dump() if hasattr(cerebras_response, 'model_dump') else cerebras_response + ) + return 
ChatCompletion.model_validate(response_dict) + + # Temporarily replace the method + self.client.chat.completions.create = create_without_web_search # type: ignore + + try: + # Call the parent implementation + return await super()._completions_create(messages, stream, model_settings, model_request_parameters) # type: ignore + finally: + # Restore the original method + self.client.chat.completions.create = original_create # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index 4b0f2bb48c..3b83475348 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -31,7 +31,7 @@ def name(self) -> str: @property def base_url(self) -> str: - return 'https://api.cerebras.ai/v1' + return 'https://api.cerebras.ai' @property def client(self) -> AsyncCerebras: @@ -83,7 +83,7 @@ def __init__( self._client = cerebras_client else: api_key = api_key or os.getenv('CEREBRAS_API_KEY') - base_url = base_url or 'https://api.cerebras.ai/v1' + base_url = base_url or 'https://api.cerebras.ai' if not api_key: raise UserError( diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index 7d6b2eb8a2..3bed48efb1 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -108,4 +108,57 @@ interactions: status: code: 200 message: OK +- request: + body: '' + headers: + accept: + - application/json + connection: + - keep-alive + host: + - api.cerebras.ai + method: GET + uri: https://api.cerebras.ai/v1/models + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + content-length: + - '479' + content-type: + - application/json + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=3600; includeSubDomains + parsed_body: + 
data: + - created: 0 + id: llama-3.3-70b + object: model + owned_by: Cerebras + - created: 0 + id: llama3.1-8b + object: model + owned_by: Cerebras + - created: 0 + id: zai-glm-4.6 + object: model + owned_by: Cerebras + - created: 0 + id: qwen-3-32b + object: model + owned_by: Cerebras + - created: 0 + id: gpt-oss-120b + object: model + owned_by: Cerebras + - created: 0 + id: qwen-3-235b-a22b-instruct-2507 + object: model + owned_by: Cerebras + object: list + status: + code: 200 + message: OK version: 1 From 12a46aeb19de6aca61a39ecd1b86e7195c6f24ce Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Thu, 20 Nov 2025 15:28:42 -0800 Subject: [PATCH 7/9] refactor: Cerebras to use OpenAI client instead of SDK - Changed CerebrasProvider to use AsyncOpenAI instead of AsyncCerebras SDK - Simplified CerebrasModel by removing custom _completions_create override - Updated dependency from cerebras-cloud-sdk to openai package - Follows OpenRouter pattern for consistency - Reduced codebase by ~200 lines while maintaining all functionality - All Cerebras tests passing (5/5) This aligns with Pydantic team's request to use OpenAI-compatible approach and removes the need for a separate SDK dependency. 
--- docs/models/cerebras.md | 2 +- .../pydantic_ai/models/__init__.py | 2 +- .../pydantic_ai/models/cerebras.py | 112 ++---------------- .../pydantic_ai/providers/cerebras.py | 75 ++++++------ pydantic_ai_slim/pyproject.toml | 2 +- tests/models/test_cerebras.py | 16 +-- tests/models/test_openai.py | 6 +- uv.lock | 21 +--- 8 files changed, 69 insertions(+), 167 deletions(-) diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md index e1618390a3..22ca397c5e 100644 --- a/docs/models/cerebras.md +++ b/docs/models/cerebras.md @@ -2,7 +2,7 @@ ## Install -To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs the `cerebras-cloud-sdk`): +To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs `openai`): ```bash pip install "pydantic-ai-slim[cerebras]" diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 096317535c..85d3b6d87d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -819,7 +819,7 @@ def infer_model( # noqa: C901 if model_kind == 'cerebras': from .cerebras import CerebrasModel - return CerebrasModel(model_name, provider=provider) + return CerebrasModel(model_name, provider=provider) # type: ignore[arg-type] elif model_kind == 'openai-chat': from .openai import OpenAIChatModel diff --git a/pydantic_ai_slim/pydantic_ai/models/cerebras.py b/pydantic_ai_slim/pydantic_ai/models/cerebras.py index c08b44478e..8f1a16e95f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/models/cerebras.py @@ -3,24 +3,21 @@ from __future__ import annotations as _annotations from dataclasses import dataclass -from typing import Any, Literal +from typing import Literal + +from ..profiles import ModelProfileSpec +from ..providers 
import Provider +from ..settings import ModelSettings +from .openai import OpenAIChatModel try: - from cerebras.cloud.sdk import AsyncCerebras # noqa: F401 + from openai import AsyncOpenAI except ImportError as _import_error: # pragma: no cover raise ImportError( - 'Please install the `cerebras-cloud-sdk` package to use the Cerebras model, ' - 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' + 'Please install the `openai` package to use the Cerebras model, ' + 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' ) from _import_error -from ..profiles import ModelProfile, ModelProfileSpec -from ..profiles.harmony import harmony_model_profile -from ..profiles.meta import meta_model_profile -from ..profiles.qwen import qwen_model_profile -from ..providers import Provider -from ..settings import ModelSettings -from .openai import OpenAIChatModel, OpenAIModelProfile # type: ignore[attr-defined] - __all__ = ('CerebrasModel', 'CerebrasModelName') CerebrasModelName = Literal[ @@ -46,7 +43,7 @@ def __init__( self, model_name: CerebrasModelName, *, - provider: Literal['cerebras'] | Provider[Any] = 'cerebras', + provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): @@ -54,93 +51,8 @@ def __init__( Args: model_name: The name of the Cerebras model to use. - provider: The provider to use. Can be 'cerebras' or a Provider instance. + provider: The provider to use. Defaults to 'cerebras'. profile: The model profile to use. Defaults to a profile based on the model name. settings: Model-specific settings that will be used as defaults for this model. 
""" - if provider == 'cerebras': - from ..providers.cerebras import CerebrasProvider - - # Extract api_key from settings if provided - api_key = settings.get('api_key') if settings else None - provider = CerebrasProvider(api_key=api_key) if api_key else CerebrasProvider() # type: ignore[call-overload] - - # Use our custom model_profile method if no profile is provided - if profile is None: - profile = self._cerebras_model_profile - - super().__init__(model_name, provider=provider, profile=profile, settings=settings) # type: ignore[arg-type] - - def _cerebras_model_profile(self, model_name: str) -> ModelProfile: - """Get the model profile for this Cerebras model. - - Returns a profile with web search disabled since Cerebras doesn't support it. - """ - model_name_lower = model_name.lower() - - # Get base profile based on model family - if model_name_lower.startswith('llama'): - base_profile = meta_model_profile(model_name) - elif model_name_lower.startswith('qwen'): - base_profile = qwen_model_profile(model_name) - elif model_name_lower.startswith('gpt-oss'): - base_profile = harmony_model_profile(model_name) - else: - # Default profile for unknown models - base_profile = ModelProfile() - - # Wrap in OpenAIModelProfile with web search disabled - return OpenAIModelProfile( - openai_chat_supports_web_search=False, - ).update(base_profile) - - async def _completions_create( - self, - messages: list[Any], - stream: bool, - model_settings: dict[str, Any], - model_request_parameters: Any, - ) -> Any: - """Override to remove web_search_options parameter and convert Cerebras response to OpenAI format.""" - from openai._types import NOT_GIVEN - from openai.types.chat import ChatCompletion - - # Get the original client method - original_create = self.client.chat.completions.create - - # Create a wrapper that removes web_search_options and filters OMIT values - async def create_without_web_search(**kwargs): - # Remove web_search_options if present - 
kwargs.pop('web_search_options', None) - - # Remove all keys with OMIT or NOT_GIVEN values - keys_to_remove = [] - for key, value in kwargs.items(): - # Check if it's OMIT by checking the type name - if hasattr(value, '__class__') and value.__class__.__name__ == 'Omit': - keys_to_remove.append(key) - elif value is NOT_GIVEN: - keys_to_remove.append(key) - - for key in keys_to_remove: - del kwargs[key] - - # Call Cerebras SDK - cerebras_response = await original_create(**kwargs) - - # Convert Cerebras response to OpenAI ChatCompletion - # The Cerebras SDK returns a compatible structure, we just need to convert the type - response_dict = ( - cerebras_response.model_dump() if hasattr(cerebras_response, 'model_dump') else cerebras_response - ) - return ChatCompletion.model_validate(response_dict) - - # Temporarily replace the method - self.client.chat.completions.create = create_without_web_search # type: ignore - - try: - # Call the parent implementation - return await super()._completions_create(messages, stream, model_settings, model_request_parameters) # type: ignore - finally: - # Restore the original method - self.client.chat.completions.create = original_create # type: ignore + super().__init__(model_name, provider=provider, profile=profile, settings=settings) diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index 3b83475348..543153bd38 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -10,19 +10,20 @@ from pydantic_ai.models import cached_async_http_client from pydantic_ai.profiles.harmony import harmony_model_profile from pydantic_ai.profiles.meta import meta_model_profile +from pydantic_ai.profiles.openai import OpenAIModelProfile from pydantic_ai.profiles.qwen import qwen_model_profile from pydantic_ai.providers import Provider try: - from cerebras.cloud.sdk import AsyncCerebras + from openai import AsyncOpenAI except 
ImportError as _import_error: # pragma: no cover raise ImportError( - 'Please install the `cerebras-cloud-sdk` package to use the Cerebras provider, ' + 'Please install the `openai` package to use the Cerebras provider, ' 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' ) from _import_error -class CerebrasProvider(Provider[AsyncCerebras]): +class CerebrasProvider(Provider[AsyncOpenAI]): """Provider for Cerebras API.""" @property @@ -31,10 +32,10 @@ def name(self) -> str: @property def base_url(self) -> str: - return 'https://api.cerebras.ai' + return 'https://api.cerebras.ai/v1' @property - def client(self) -> AsyncCerebras: + def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: @@ -44,27 +45,37 @@ def model_profile(self, model_name: str) -> ModelProfile | None: 'gpt-oss': harmony_model_profile, } + profile = None + model_name_lower = model_name.lower() for prefix, profile_func in prefix_to_profile.items(): - model_name = model_name.lower() - if model_name.startswith(prefix): - return profile_func(model_name) + if model_name_lower.startswith(prefix): + profile = profile_func(model_name) + break - return None + # Wrap in OpenAIModelProfile with web search disabled + # Cerebras doesn't support web search + return OpenAIModelProfile(openai_chat_supports_web_search=False).update(profile) @overload - def __init__(self, *, cerebras_client: AsyncCerebras | None = None) -> None: ... + def __init__(self) -> None: ... @overload - def __init__( - self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None - ) -> None: ... + def __init__(self, *, api_key: str) -> None: ... + + @overload + def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ... + + @overload + def __init__(self, *, http_client: httpx.AsyncClient) -> None: ... 
+ + @overload + def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, - base_url: str | None = None, - cerebras_client: AsyncCerebras | None = None, + openai_client: AsyncOpenAI | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Cerebras provider. @@ -72,26 +83,20 @@ def __init__( Args: api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable will be used if available. - base_url: The base url for the Cerebras requests. If not provided, defaults to Cerebras's base url. - cerebras_client: An existing `AsyncCerebras` client to use. If provided, `api_key` and `http_client` must be `None`. + openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ - if cerebras_client is not None: - assert http_client is None, 'Cannot provide both `cerebras_client` and `http_client`' - assert api_key is None, 'Cannot provide both `cerebras_client` and `api_key`' - assert base_url is None, 'Cannot provide both `cerebras_client` and `base_url`' - self._client = cerebras_client + api_key = api_key or os.getenv('CEREBRAS_API_KEY') + if not api_key and openai_client is None: + raise UserError( + 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' + 'to use the Cerebras provider.' + ) + + if openai_client is not None: + self._client = openai_client + elif http_client is not None: + self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: - api_key = api_key or os.getenv('CEREBRAS_API_KEY') - base_url = base_url or 'https://api.cerebras.ai' - - if not api_key: - raise UserError( - 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' - 'to use the Cerebras provider.' 
- ) - elif http_client is not None: - self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client) - else: - http_client = cached_async_http_client(provider='cerebras') - self._client = AsyncCerebras(base_url=base_url, api_key=api_key, http_client=http_client) + http_client = cached_async_http_client(provider='cerebras') + self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 5440572ed3..807cab4696 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -72,7 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.51.0"] anthropic = ["anthropic>=0.70.0"] -cerebras = ["cerebras-cloud-sdk>=1.0.0"] +cerebras = ["openai>=1.107.2"] groq = ["groq>=0.25.0"] mistral = ["mistralai>=1.9.10"] bedrock = ["boto3>=1.40.14"] diff --git a/tests/models/test_cerebras.py b/tests/models/test_cerebras.py index a3fc11444d..71b7a9e164 100644 --- a/tests/models/test_cerebras.py +++ b/tests/models/test_cerebras.py @@ -13,36 +13,38 @@ def test_cerebras_model_init(): - model = CerebrasModel('llama-3.3-70b', settings={'api_key': 'test_key'}) + model = CerebrasModel('llama-3.3-70b', provider=CerebrasProvider(api_key='test_key')) assert model.model_name == 'llama-3.3-70b' - assert isinstance(model._provider, CerebrasProvider) - assert model._provider.client.api_key == 'test_key' + assert isinstance(model._provider, CerebrasProvider) # type: ignore[reportPrivateUsage] + assert model._provider.client.api_key == 'test_key' # type: ignore[reportPrivateUsage] def test_cerebras_model_profile(): + provider = CerebrasProvider(api_key='test_key') + # Test Llama model - model = CerebrasModel('llama-3.3-70b', settings={'api_key': 'test_key'}) + model = CerebrasModel('llama-3.3-70b', provider=provider) profile = model.profile assert 
isinstance(profile, OpenAIModelProfile) assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False # Test Qwen model - model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', settings={'api_key': 'test_key'}) + model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', provider=provider) profile = model.profile assert isinstance(profile, OpenAIModelProfile) assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False # Test GPT-OSS model - model = CerebrasModel('gpt-oss-120b', settings={'api_key': 'test_key'}) + model = CerebrasModel('gpt-oss-120b', provider=provider) profile = model.profile assert isinstance(profile, OpenAIModelProfile) assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False # Test unknown model - use zai-glm which is valid but won't match any prefix - model = CerebrasModel('zai-glm-4.6', settings={'api_key': 'test_key'}) + model = CerebrasModel('zai-glm-4.6', provider=provider) profile = model.profile assert isinstance(profile, OpenAIModelProfile) assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index e68c64abe3..f2bc0c2477 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -3022,7 +3022,7 @@ async def test_openai_model_settings_temperature_ignored_on_gpt_5(allow_model_re async def test_openai_model_cerebras_provider(allow_model_requests: None, cerebras_api_key: str): - m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key)) + m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type] agent = Agent(m) result = await 
agent.run('What is the capital of France?') @@ -3034,7 +3034,7 @@ class Location(TypedDict): city: str country: str - m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key)) + m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type] agent = Agent(m, output_type=Location) result = await agent.run('What is the capital of France?') @@ -3042,7 +3042,7 @@ class Location(TypedDict): async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None, cerebras_api_key: str): - m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key)) + m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type] agent = Agent(m) result = await agent.run('What is the capital of France?') diff --git a/uv.lock b/uv.lock index 9e022c3e1a..50e56d3756 100644 --- a/uv.lock +++ b/uv.lock @@ -736,23 +736,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/7d/383bafeabb54c17fe5b6d5aca4e863e6b7df10bcc833b34aa169e9dfce1a/cbor2-5.7.1-py3-none-any.whl", hash = "sha256:68834e4eff2f56629ce6422b0634bc3f74c5a4269de5363f5265fe452c706ba7", size = 23829, upload-time = "2025-10-24T09:23:05.54Z" }, ] -[[package]] -name = "cerebras-cloud-sdk" -version = "1.59.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/6e/4e/24b9e248a6dcff4ab67a33a359c5c2475c2b1f1ddf0ae46e557c761eb9b6/cerebras_cloud_sdk-1.59.0.tar.gz", hash = "sha256:0f0af8debc8871d4ec21528dc79b276a31e7a7f578b4a60fae3a1ae27d1f2486", size = 127687, upload-time = "2025-11-06T21:52:23.601Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/a5/12/a86b2c95cda487970ca12724b8bb54eef82e94204f2caa0066367cbc55eb/cerebras_cloud_sdk-1.59.0-py3-none-any.whl", hash = "sha256:89228068a8d649cad0a9231b34bdc2fcdb925241eb3434d788dfea317f274ed7", size = 96186, upload-time = "2025-11-06T21:52:22.243Z" }, -] - [[package]] name = "certifi" version = "2025.1.31" @@ -5587,7 +5570,7 @@ bedrock = [ { name = "boto3" }, ] cerebras = [ - { name = "cerebras-cloud-sdk" }, + { name = "openai" }, ] cli = [ { name = "argcomplete" }, @@ -5682,7 +5665,6 @@ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.70.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.40.14" }, - { name = "cerebras-cloud-sdk", marker = "extra == 'cerebras'", specifier = ">=1.0.0" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.18.0" }, { name = "dbos", marker = "extra == 'dbos'", specifier = ">=1.14.0" }, { name = "ddgs", marker = "extra == 'duckduckgo'", specifier = ">=9.0.0" }, @@ -5699,6 +5681,7 @@ requires-dist = [ { name = "logfire", extras = ["httpx"], marker = "extra == 'logfire'", specifier = ">=3.14.1" }, { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.12.3" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.9.10" }, + { name = "openai", marker = "extra == 'cerebras'", specifier = ">=1.107.2" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.107.2" }, { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "outlines", marker = "extra == 'outlines-vllm-offline'", specifier = ">=1.0.0,<1.3.0" }, From 1229a8b4aba1de9164487ec88653e034e0c1067b Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Thu, 20 Nov 2025 15:52:27 -0800 Subject: [PATCH 8/9] fix: remove unnecessary type ignore comments in Cerebras tests --- tests/models/test_openai.py | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index f2bc0c2477..e68c64abe3 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -3022,7 +3022,7 @@ async def test_openai_model_settings_temperature_ignored_on_gpt_5(allow_model_re async def test_openai_model_cerebras_provider(allow_model_requests: None, cerebras_api_key: str): - m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type] + m = OpenAIChatModel('llama3.3-70b', provider=CerebrasProvider(api_key=cerebras_api_key)) agent = Agent(m) result = await agent.run('What is the capital of France?') @@ -3034,7 +3034,7 @@ class Location(TypedDict): city: str country: str - m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type] + m = OpenAIChatModel('qwen-3-coder-480b', provider=CerebrasProvider(api_key=cerebras_api_key)) agent = Agent(m, output_type=Location) result = await agent.run('What is the capital of France?') @@ -3042,7 +3042,7 @@ class Location(TypedDict): async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None, cerebras_api_key: str): - m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key)) # type: ignore[arg-type] + m = OpenAIChatModel('gpt-oss-120b', provider=CerebrasProvider(api_key=cerebras_api_key)) agent = Agent(m) result = await agent.run('What is the capital of France?') From d7b16641862b0dc9f50f2a09f02e21c5078dc685 Mon Sep 17 00:00:00 2001 From: Seb Duerr Date: Thu, 20 Nov 2025 16:20:45 -0800 Subject: [PATCH 9/9] test: add coverage for CerebrasProvider error cases - Add test for UserError when no API key provided - Add test for passing openai_client directly - Achieves 100% coverage for cerebras.py --- tests/models/test_cerebras.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git 
a/tests/models/test_cerebras.py b/tests/models/test_cerebras.py index 71b7a9e164..f85bab4ca9 100644 --- a/tests/models/test_cerebras.py +++ b/tests/models/test_cerebras.py @@ -1,11 +1,16 @@ +import re + import pytest from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer +from pydantic_ai.exceptions import UserError from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile -from ..conftest import try_import +from ..conftest import TestEnv, try_import with try_import() as imports_successful: + from openai import AsyncOpenAI + from pydantic_ai.models.cerebras import CerebrasModel from pydantic_ai.providers.cerebras import CerebrasProvider @@ -48,3 +53,21 @@ def test_cerebras_model_profile(): profile = model.profile assert isinstance(profile, OpenAIModelProfile) assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + +def test_cerebras_provider_need_api_key(env: TestEnv) -> None: + env.remove('CEREBRAS_API_KEY') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' + 'to use the Cerebras provider.' + ), + ): + CerebrasProvider() + + +def test_cerebras_provider_pass_openai_client() -> None: + openai_client = AsyncOpenAI(api_key='test_key', base_url='https://api.cerebras.ai/v1') + provider = CerebrasProvider(openai_client=openai_client) + assert provider.client == openai_client