diff --git a/docs/models/cerebras.md b/docs/models/cerebras.md new file mode 100644 index 0000000000..22ca397c5e --- /dev/null +++ b/docs/models/cerebras.md @@ -0,0 +1,83 @@ +# Cerebras + +## Install + +To use `CerebrasModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cerebras` optional group (which installs `openai`): + +```bash +pip install "pydantic-ai-slim[cerebras]" +``` + +or + +```bash +uv add "pydantic-ai-slim[cerebras]" +``` + +## Configuration + +To use [Cerebras](https://cerebras.ai/) through their API, go to [cloud.cerebras.ai](https://cloud.cerebras.ai/?utm_source=3pi_pydantic-ai&utm_campaign=partner_doc) and follow your nose until you find the place to generate an API key. + +`CerebrasModelName` contains a list of available Cerebras models. + +## Environment variable + +Once you have the API key, you can set it as an environment variable: + +```bash +export CEREBRAS_API_KEY='your-api-key' +``` + +You can then use `CerebrasModel` by name: + +```python +from pydantic_ai import Agent + +agent = Agent('cerebras:llama-3.3-70b') +... +``` + +Or initialise the model directly with just the model name: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.cerebras import CerebrasModel + +model = CerebrasModel('llama-3.3-70b') +agent = Agent(model) +... +``` + +## `provider` argument + +You can provide a custom `Provider` via the `provider` argument: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.cerebras import CerebrasModel +from pydantic_ai.providers.cerebras import CerebrasProvider + +model = CerebrasModel( + 'llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key') +) +agent = Agent(model) +... 
+``` + +You can also customize the `CerebrasProvider` with a custom `httpx.AsyncClient`: + +```python +from httpx import AsyncClient + +from pydantic_ai import Agent +from pydantic_ai.models.cerebras import CerebrasModel +from pydantic_ai.providers.cerebras import CerebrasProvider + +custom_http_client = AsyncClient(timeout=30) +model = CerebrasModel( + 'llama-3.3-70b', + provider=CerebrasProvider(api_key='your-api-key', http_client=custom_http_client), +) +agent = Agent(model) +... +``` diff --git a/mkdocs.yml b/mkdocs.yml index 5f827ae71b..b1cfbba37f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -30,6 +30,7 @@ nav: - models/anthropic.md - models/google.md - models/bedrock.md + - models/cerebras.md - models/cohere.md - models/groq.md - models/mistral.md diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index b43681b0a4..85d3b6d87d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -134,7 +134,6 @@ 'cerebras:llama-3.3-70b', 'cerebras:llama3.1-8b', 'cerebras:qwen-3-235b-a22b-instruct-2507', - 'cerebras:qwen-3-235b-a22b-thinking-2507', 'cerebras:qwen-3-32b', 'cerebras:zai-glm-4.6', 'cohere:c4ai-aya-expanse-32b', @@ -800,7 +799,6 @@ def infer_model( # noqa: C901 'openai', 'azure', 'deepseek', - 'cerebras', 'fireworks', 'github', 'grok', @@ -818,7 +816,11 @@ def infer_model( # noqa: C901 elif model_kind in ('google-gla', 'google-vertex'): model_kind = 'google' - if model_kind == 'openai-chat': + if model_kind == 'cerebras': + from .cerebras import CerebrasModel + + return CerebrasModel(model_name, provider=provider) # type: ignore[arg-type] + elif model_kind == 'openai-chat': from .openai import OpenAIChatModel return OpenAIChatModel(model_name, provider=provider) diff --git a/pydantic_ai_slim/pydantic_ai/models/cerebras.py b/pydantic_ai_slim/pydantic_ai/models/cerebras.py new file mode 100644 index 0000000000..8f1a16e95f --- /dev/null +++ 
b/pydantic_ai_slim/pydantic_ai/models/cerebras.py @@ -0,0 +1,58 @@ +"""Cerebras model implementation using OpenAI-compatible API.""" + +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from typing import Literal + +from ..profiles import ModelProfileSpec +from ..providers import Provider +from ..settings import ModelSettings +from .openai import OpenAIChatModel + +try: + from openai import AsyncOpenAI +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `openai` package to use the Cerebras model, ' + 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' + ) from _import_error + +__all__ = ('CerebrasModel', 'CerebrasModelName') + +CerebrasModelName = Literal[ + 'gpt-oss-120b', + 'llama-3.3-70b', + 'llama3.1-8b', + 'qwen-3-235b-a22b-instruct-2507', + 'qwen-3-32b', + 'zai-glm-4.6', +] + + +@dataclass(init=False) +class CerebrasModel(OpenAIChatModel): + """A model that uses Cerebras's OpenAI-compatible API. + + Cerebras provides ultra-fast inference powered by the Wafer-Scale Engine (WSE). + + Apart from `__init__`, all methods are private or match those of the base class. + """ + + def __init__( + self, + model_name: CerebrasModelName, + *, + provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras', + profile: ModelProfileSpec | None = None, + settings: ModelSettings | None = None, + ): + """Initialize a Cerebras model. + + Args: + model_name: The name of the Cerebras model to use. + provider: The provider to use. Defaults to 'cerebras'. + profile: The model profile to use. Defaults to a profile based on the model name. + settings: Model-specific settings that will be used as defaults for this model. 
+ """ + super().__init__(model_name, provider=provider, profile=profile, settings=settings) diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index 267cf41b8c..543153bd38 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -10,7 +10,7 @@ from pydantic_ai.models import cached_async_http_client from pydantic_ai.profiles.harmony import harmony_model_profile from pydantic_ai.profiles.meta import meta_model_profile -from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile +from pydantic_ai.profiles.openai import OpenAIModelProfile from pydantic_ai.profiles.qwen import qwen_model_profile from pydantic_ai.providers import Provider @@ -19,7 +19,7 @@ except ImportError as _import_error: # pragma: no cover raise ImportError( 'Please install the `openai` package to use the Cerebras provider, ' - 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + 'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"`' ) from _import_error @@ -39,27 +39,22 @@ def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile} + prefix_to_profile = { + 'llama': meta_model_profile, + 'qwen': qwen_model_profile, + 'gpt-oss': harmony_model_profile, + } profile = None + model_name_lower = model_name.lower() for prefix, profile_func in prefix_to_profile.items(): - model_name = model_name.lower() - if model_name.startswith(prefix): + if model_name_lower.startswith(prefix): profile = profile_func(model_name) + break - # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features, - # Cerebras doesn't support some model settings. 
- unsupported_model_settings = ( - 'frequency_penalty', - 'logit_bias', - 'presence_penalty', - 'parallel_tool_calls', - 'service_tier', - ) - return OpenAIModelProfile( - json_schema_transformer=OpenAIJsonSchemaTransformer, - openai_unsupported_model_settings=unsupported_model_settings, - ).update(profile) + # Wrap in OpenAIModelProfile with web search disabled + # Cerebras doesn't support web search + return OpenAIModelProfile(openai_chat_supports_web_search=False).update(profile) @overload def __init__(self) -> None: ... @@ -70,6 +65,9 @@ def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ... + @overload + def __init__(self, *, http_client: httpx.AsyncClient) -> None: ... + @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... @@ -80,6 +78,14 @@ def __init__( openai_client: AsyncOpenAI | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: + """Create a new Cerebras provider. + + Args: + api_key: The API key to use for authentication, if not provided, the `CEREBRAS_API_KEY` environment variable + will be used if available. + openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`. + http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. 
+ """ api_key = api_key or os.getenv('CEREBRAS_API_KEY') if not api_key and openai_client is None: raise UserError( diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 86b7d65e05..0dda8adaec 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -72,6 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.51.0"] anthropic = ["anthropic>=0.70.0"] +cerebras = ["openai>=1.107.2"] groq = ["groq>=0.25.0"] mistral = ["mistralai>=1.9.10"] bedrock = ["boto3>=1.40.14"] diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index 1c55c62948..3bed48efb1 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -124,7 +124,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 content-length: - - '570' + - '479' content-type: - application/json referrer-policy: @@ -133,24 +133,20 @@ interactions: - max-age=3600; includeSubDomains parsed_body: data: - - created: 0 - id: qwen-3-235b-a22b-thinking-2507 - object: model - owned_by: Cerebras - created: 0 id: llama-3.3-70b object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-instruct-2507 + id: llama3.1-8b object: model owned_by: Cerebras - created: 0 - id: qwen-3-32b + id: zai-glm-4.6 object: model owned_by: Cerebras - created: 0 - id: zai-glm-4.6 + id: qwen-3-32b object: model owned_by: Cerebras - created: 0 @@ -158,7 +154,7 @@ interactions: object: model owned_by: Cerebras - created: 0 - id: llama3.1-8b + id: qwen-3-235b-a22b-instruct-2507 object: model owned_by: Cerebras object: list diff --git a/tests/models/test_cerebras.py b/tests/models/test_cerebras.py new file mode 100644 index 0000000000..f85bab4ca9 --- /dev/null +++ b/tests/models/test_cerebras.py @@ -0,0 +1,73 
@@ +import re + +import pytest + +from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer +from pydantic_ai.exceptions import UserError +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + from openai import AsyncOpenAI + + from pydantic_ai.models.cerebras import CerebrasModel + from pydantic_ai.providers.cerebras import CerebrasProvider + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') + + +def test_cerebras_model_init(): + model = CerebrasModel('llama-3.3-70b', provider=CerebrasProvider(api_key='test_key')) + assert model.model_name == 'llama-3.3-70b' + assert isinstance(model._provider, CerebrasProvider) # type: ignore[reportPrivateUsage] + assert model._provider.client.api_key == 'test_key' # type: ignore[reportPrivateUsage] + + +def test_cerebras_model_profile(): + provider = CerebrasProvider(api_key='test_key') + + # Test Llama model + model = CerebrasModel('llama-3.3-70b', provider=provider) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + # Test Qwen model + model = CerebrasModel('qwen-3-235b-a22b-instruct-2507', provider=provider) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + # Test GPT-OSS model + model = CerebrasModel('gpt-oss-120b', provider=provider) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False 
+ + # Test unknown model - use zai-glm which is valid but won't match any prefix + model = CerebrasModel('zai-glm-4.6', provider=provider) + profile = model.profile + assert isinstance(profile, OpenAIModelProfile) + assert OpenAIModelProfile.from_profile(profile).openai_chat_supports_web_search is False + + +def test_cerebras_provider_need_api_key(env: TestEnv) -> None: + env.remove('CEREBRAS_API_KEY') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' + 'to use the Cerebras provider.' + ), + ): + CerebrasProvider() + + +def test_cerebras_provider_pass_openai_client() -> None: + openai_client = AsyncOpenAI(api_key='test_key', base_url='https://api.cerebras.ai/v1') + provider = CerebrasProvider(openai_client=openai_client) + assert provider.client == openai_client diff --git a/tests/providers/test_cerebras.py b/tests/providers/test_cerebras.py deleted file mode 100644 index 3dcf04d3f4..0000000000 --- a/tests/providers/test_cerebras.py +++ /dev/null @@ -1,86 +0,0 @@ -import re - -import httpx -import pytest -from pytest_mock import MockerFixture - -from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer -from pydantic_ai.exceptions import UserError -from pydantic_ai.profiles.harmony import harmony_model_profile -from pydantic_ai.profiles.meta import meta_model_profile -from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer -from pydantic_ai.profiles.qwen import qwen_model_profile - -from ..conftest import TestEnv, try_import - -with try_import() as imports_successful: - import openai - - from pydantic_ai.providers.cerebras import CerebrasProvider - -pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') - - -def test_cerebras_provider(): - provider = CerebrasProvider(api_key='ghp_test_token') - assert provider.name == 'cerebras' - assert provider.base_url == 'https://api.cerebras.ai/v1' - assert 
isinstance(provider.client, openai.AsyncOpenAI) - assert provider.client.api_key == 'ghp_test_token' - - -def test_cerebras_provider_need_api_key(env: TestEnv) -> None: - env.remove('CEREBRAS_API_KEY') - with pytest.raises( - UserError, - match=re.escape( - 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` ' - 'to use the Cerebras provider.' - ), - ): - CerebrasProvider() - - -def test_github_provider_pass_http_client() -> None: - http_client = httpx.AsyncClient() - provider = CerebrasProvider(http_client=http_client, api_key='ghp_test_token') - assert provider.client._client == http_client # type: ignore[reportPrivateUsage] - - -def test_github_pass_openai_client() -> None: - openai_client = openai.AsyncOpenAI(api_key='ghp_test_token') - provider = CerebrasProvider(openai_client=openai_client) - assert provider.client == openai_client - - -def test_cerebras_provider_model_profile(mocker: MockerFixture): - provider = CerebrasProvider(api_key='api-key') - - ns = 'pydantic_ai.providers.cerebras' - meta_model_profile_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile) - qwen_model_profile_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile) - harmony_model_profile_mock = mocker.patch(f'{ns}.harmony_model_profile', wraps=harmony_model_profile) - - meta_profile = provider.model_profile('llama4-maverick-instruct-basic') - meta_model_profile_mock.assert_called_with('llama4-maverick-instruct-basic') - assert meta_profile is not None - assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - - qwen_profile = provider.model_profile('qwen3-235b-a22b') - qwen_model_profile_mock.assert_called_with('qwen3-235b-a22b') - assert qwen_profile is not None - assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - - qwen_profile = provider.model_profile('qwen-3-coder-480b') - qwen_model_profile_mock.assert_called_with('qwen-3-coder-480b') - assert 
qwen_profile is not None - assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - - openai_profile = provider.model_profile('gpt-oss-120b') - harmony_model_profile_mock.assert_called_with('gpt-oss-120b') - assert openai_profile is not None - assert openai_profile.json_schema_transformer == OpenAIJsonSchemaTransformer - - unknown_profile = provider.model_profile('unknown-model') - assert unknown_profile is not None - assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer diff --git a/uv.lock b/uv.lock index a3738b262b..27ea56a733 100644 --- a/uv.lock +++ b/uv.lock @@ -5569,6 +5569,9 @@ anthropic = [ bedrock = [ { name = "boto3" }, ] +cerebras = [ + { name = "openai" }, +] cli = [ { name = "argcomplete" }, { name = "prompt-toolkit" }, @@ -5678,6 +5681,7 @@ requires-dist = [ { name = "logfire", extras = ["httpx"], marker = "extra == 'logfire'", specifier = ">=3.14.1" }, { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.18.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.9.10" }, + { name = "openai", marker = "extra == 'cerebras'", specifier = ">=1.107.2" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.107.2" }, { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "outlines", marker = "extra == 'outlines-vllm-offline'", specifier = ">=1.0.0,<1.3.0" }, @@ -5706,7 +5710,7 @@ requires-dist = [ { name = "typing-inspection", specifier = ">=0.4.0" }, { name = "vllm", marker = "(python_full_version < '3.12' and platform_machine != 'x86_64' and extra == 'outlines-vllm-offline') or (python_full_version < '3.12' and sys_platform != 'darwin' and extra == 'outlines-vllm-offline')" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "dbos", "duckduckgo", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "outlines-llamacpp", "outlines-mlxlm", "outlines-sglang", "outlines-transformers", 
"outlines-vllm-offline", "prefect", "retries", "tavily", "temporal", "ui", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cerebras", "cli", "cohere", "dbos", "duckduckgo", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "outlines-llamacpp", "outlines-mlxlm", "outlines-sglang", "outlines-transformers", "outlines-vllm-offline", "prefect", "retries", "tavily", "temporal", "ui", "vertexai"] [[package]] name = "pydantic-core"