diff --git a/docs/evals.md b/docs/evals.md index e1820b88c..81112e69b 100644 --- a/docs/evals.md +++ b/docs/evals.md @@ -341,8 +341,8 @@ async def double_number(input_value: int) -> int: # Run evaluation with unlimited concurrency t0 = time.time() report_default = dataset.evaluate_sync(double_number) -print(f'Evaluation took less than 0.5s: {time.time() - t0 < 0.5}') -#> Evaluation took less than 0.5s: True +print(f'Evaluation took less than 1s: {time.time() - t0 < 1}') +#> Evaluation took less than 1s: True report_default.print(include_input=True, include_output=True, include_durations=False) # (1)! """ diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 9eaae9352..afb006af7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -44,6 +44,7 @@ ) from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider +from ..providers.anthropic import ASYNC_ANTHROPIC_CLIENT from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -56,7 +57,7 @@ ) try: - from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream + from anthropic import NOT_GIVEN, APIStatusError, AsyncStream from anthropic.types.beta import ( BetaBase64PDFBlockParam, BetaBase64PDFSourceParam, @@ -142,7 +143,7 @@ class AnthropicModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - client: AsyncAnthropic = field(repr=False) + client: ASYNC_ANTHROPIC_CLIENT = field(repr=False) _model_name: AnthropicModelName = field(repr=False) _system: str = field(default='anthropic', repr=False) @@ -151,7 +152,7 @@ def __init__( self, model_name: AnthropicModelName, *, - provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic', + provider: Literal['anthropic'] | Provider[ASYNC_ANTHROPIC_CLIENT] = 'anthropic', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): @@ -161,7 +162,7 @@ def __init__( model_name: The name of the Anthropic model to use. List of model names available [here](https://docs.anthropic.com/en/docs/about-claude/models). provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an - instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used. + instance of `Provider[ASYNC_ANTHROPIC_CLIENT]`. If not provided, the other parameters will be used. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. settings: Default model settings for this model instance. """ diff --git a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py index 20bc3255e..85afcee12 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py @@ -1,26 +1,29 @@ from __future__ import annotations as _annotations import os -from typing import overload +from typing import Union, overload import httpx -from pydantic_ai.exceptions import UserError from pydantic_ai.models import cached_async_http_client from pydantic_ai.profiles import ModelProfile from pydantic_ai.profiles.anthropic import anthropic_model_profile from pydantic_ai.providers import Provider try: - from anthropic import AsyncAnthropic + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock except ImportError as _import_error: # pragma: no cover raise ImportError( 'Please install the `anthropic` package to use the Anthropic provider, ' 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`' ) from _import_error +from typing_extensions import TypeAlias -class AnthropicProvider(Provider[AsyncAnthropic]): +ASYNC_ANTHROPIC_CLIENT: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock] + + +class AnthropicProvider(Provider[ASYNC_ANTHROPIC_CLIENT]): """Provider for Anthropic API.""" @property @@ -32,23 +35,40 @@ def base_url(self) -> str: return str(self._client.base_url) @property - def client(self) -> AsyncAnthropic: + def client(self) -> ASYNC_ANTHROPIC_CLIENT: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return anthropic_model_profile(model_name) @overload - def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ... + def __init__(self, *, anthropic_client: ASYNC_ANTHROPIC_CLIENT | None = None) -> None: ... @overload def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ... + @overload + def __init__( + self, + *, + aws_secret_key: str | None = None, + aws_access_key: str | None = None, + aws_region: str | None = None, + aws_profile: str | None = None, + aws_session_token: str | None = None, + http_client: httpx.AsyncClient | None = None, + ) -> None: ... + def __init__( self, *, api_key: str | None = None, - anthropic_client: AsyncAnthropic | None = None, + aws_secret_key: str | None = None, + aws_access_key: str | None = None, + aws_region: str | None = None, + aws_profile: str | None = None, + aws_session_token: str | None = None, + anthropic_client: ASYNC_ANTHROPIC_CLIENT | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Anthropic provider. @@ -56,6 +76,11 @@ def __init__( Args: api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable will be used if available. + aws_secret_key: AWS secret access key for Bedrock authentication. + aws_access_key: AWS access key ID for Bedrock authentication. + aws_region: AWS region for Bedrock service. + aws_profile: AWS profile name for Bedrock authentication. + aws_session_token: AWS session token for temporary credentials. anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python) client to use. If provided, the `api_key` and `http_client` arguments will be ignored. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. @@ -66,14 +91,30 @@ def __init__( self._client = anthropic_client else: api_key = api_key or os.environ.get('ANTHROPIC_API_KEY') - if not api_key: - raise UserError( - 'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`' - 'to use the Anthropic provider.' - ) - - if http_client is not None: - self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) + if api_key is None: + if http_client is not None: + self._client = AsyncAnthropicBedrock( + aws_access_key=aws_access_key, + aws_secret_key=aws_secret_key, + aws_session_token=aws_session_token, + aws_profile=aws_profile, + aws_region=aws_region, + http_client=http_client, + ) + else: + http_client = cached_async_http_client(provider='anthropic') + self._client = AsyncAnthropicBedrock( + aws_access_key=aws_access_key, + aws_secret_key=aws_secret_key, + aws_session_token=aws_session_token, + aws_profile=aws_profile, + aws_region=aws_region, + http_client=http_client, + ) + else: - http_client = cached_async_http_client(provider='anthropic') - self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) + if http_client is not None: + self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) + else: + http_client = cached_async_http_client(provider='anthropic') + self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3e1b06301..982be866d 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -47,7 +47,7 @@ from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: - from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic + from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncAnthropicBedrock from anthropic.resources.beta import AsyncBeta from anthropic.types.beta import ( BetaCodeExecutionResultBlock, @@ -95,10 +95,15 @@ def test_init(): m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar')) + assert isinstance(m.client, AsyncAnthropic) assert m.client.api_key == 'foobar' assert m.model_name == 'claude-3-5-haiku-latest' assert m.system == 'anthropic' assert m.base_url == 'https://api.anthropic.com' + bedrock = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider()) + assert isinstance(bedrock.client, AsyncAnthropicBedrock) + assert m.model_name == 'claude-3-5-haiku-latest' + assert m.system == 'anthropic' @dataclass diff --git a/tests/providers/test_anthropic.py b/tests/providers/test_anthropic.py index 44f47554b..c68677aa3 100644 --- a/tests/providers/test_anthropic.py +++ b/tests/providers/test_anthropic.py @@ -3,12 +3,10 @@ import httpx import pytest -from pydantic_ai.exceptions import UserError - -from ..conftest import TestEnv, try_import +from ..conftest import try_import with try_import() as imports_successful: - from anthropic import AsyncAnthropic + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock from pydantic_ai.providers.anthropic import AnthropicProvider @@ -24,10 +22,17 @@ def test_anthropic_provider(): assert provider.client.api_key == 'api-key' -def test_anthropic_provider_need_api_key(env: TestEnv) -> None: - env.remove('ANTHROPIC_API_KEY') - with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'): - AnthropicProvider() +def test_anthropic_provider_with_aws_credentials() -> None: + provider = AnthropicProvider( + aws_secret_key='aws-secret-key', + aws_access_key='aws-access-key', + aws_region='us-west-2', + aws_profile='default', + aws_session_token='aws-session-token', + ) + assert provider.name == 'anthropic' + assert provider.base_url == 'https://bedrock-runtime.us-west-2.amazonaws.com' + assert isinstance(provider.client, AsyncAnthropicBedrock) def test_anthropic_provider_pass_http_client() -> None: @@ -36,12 +41,31 @@ def test_anthropic_provider_pass_http_client() -> None: assert isinstance(provider.client, AsyncAnthropic) # Verify the http_client is being used by the AsyncAnthropic client assert provider.client._client == http_client # type: ignore[reportPrivateUsage] + bedrock_provider = AnthropicProvider( + aws_secret_key='aws-secret-key', + aws_access_key='aws-access_key', + aws_region='us-west-2', + aws_profile='default', + aws_session_token='aws-session-token', + http_client=http_client, + ) + assert isinstance(bedrock_provider.client, AsyncAnthropicBedrock) + assert bedrock_provider.client._client == http_client # type: ignore[reportPrivateUsage] def test_anthropic_provider_pass_anthropic_client() -> None: anthropic_client = AsyncAnthropic(api_key='api-key') provider = AnthropicProvider(anthropic_client=anthropic_client) assert provider.client == anthropic_client + bedrock_client = AsyncAnthropicBedrock( + aws_secret_key='aws-secret-key', + aws_access_key='aws-access-key', + aws_region='us-west-2', + aws_profile='default', + aws_session_token='aws-session-token', + ) + provider = AnthropicProvider(anthropic_client=bedrock_client) + assert provider.client == bedrock_client def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None: @@ -50,3 +74,16 @@ def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) - monkeypatch.setenv('ANTHROPIC_BASE_URL', custom_base_url) provider = AnthropicProvider(api_key='api-key') assert provider.base_url.rstrip('/') == custom_base_url.rstrip('/') + + +def test_bedrock_anthropic_provider_with_envs(monkeypatch: pytest.MonkeyPatch) -> None: + # Test with environment variables for AWS credentials + monkeypatch.setenv('AWS_SECRET_ACCESS_KEY', 'aws-secret-access-key') + monkeypatch.setenv('AWS_ACCESS_KEY_ID', 'aws-access-key-id') + monkeypatch.setenv('AWS_SESSION_TOKEN', 'aws-session-token') + monkeypatch.setenv('AWS_PROFILE', 'default') + monkeypatch.setenv('AWS_REGION', 'us-west-2') + bedrock_provider = AnthropicProvider() + assert bedrock_provider.name == 'anthropic' + assert bedrock_provider.base_url == 'https://bedrock-runtime.us-west-2.amazonaws.com' + assert isinstance(bedrock_provider.client, AsyncAnthropicBedrock) diff --git a/tests/providers/test_provider_names.py b/tests/providers/test_provider_names.py index 0ace7835e..25f1c456a 100644 --- a/tests/providers/test_provider_names.py +++ b/tests/providers/test_provider_names.py @@ -16,7 +16,6 @@ from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai.providers.azure import AzureProvider - from pydantic_ai.providers.cohere import CohereProvider from pydantic_ai.providers.deepseek import DeepSeekProvider from pydantic_ai.providers.fireworks import FireworksProvider from pydantic_ai.providers.github import GitHubProvider @@ -33,8 +32,7 @@ from pydantic_ai.providers.vercel import VercelProvider test_infer_provider_params = [ - ('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'), - ('cohere', CohereProvider, 'CO_API_KEY'), + ('anthropic', AnthropicProvider, None), ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'), ('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'), ('vercel', VercelProvider, 'VERCEL_AI_GATEWAY_API_KEY'),