Skip to content

Allow proper type on AnthropicProvider when using Bedrock #2490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/evals.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ async def double_number(input_value: int) -> int:
# Run evaluation with unlimited concurrency
t0 = time.time()
report_default = dataset.evaluate_sync(double_number)
print(f'Evaluation took less than 0.5s: {time.time() - t0 < 0.5}')
#> Evaluation took less than 0.5s: True
print(f'Evaluation took less than 1s: {time.time() - t0 < 1}')
#> Evaluation took less than 1s: True

Comment on lines +344 to 346
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've applied this in another PR.

report_default.print(include_input=True, include_output=True, include_durations=False) # (1)!
"""
Expand Down
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from ..profiles import ModelProfileSpec
from ..providers import Provider, infer_provider
from ..providers.anthropic import ASYNC_ANTHROPIC_CLIENT
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
Expand All @@ -56,7 +57,7 @@
)

try:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
from anthropic.types.beta import (
BetaBase64PDFBlockParam,
BetaBase64PDFSourceParam,
Expand Down Expand Up @@ -142,7 +143,7 @@ class AnthropicModel(Model):
Apart from `__init__`, all methods are private or match those of the base class.
"""

client: AsyncAnthropic = field(repr=False)
client: ASYNC_ANTHROPIC_CLIENT = field(repr=False)

_model_name: AnthropicModelName = field(repr=False)
_system: str = field(default='anthropic', repr=False)
Expand All @@ -151,7 +152,7 @@ def __init__(
self,
model_name: AnthropicModelName,
*,
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
provider: Literal['anthropic'] | Provider[ASYNC_ANTHROPIC_CLIENT] = 'anthropic',
profile: ModelProfileSpec | None = None,
settings: ModelSettings | None = None,
):
Expand All @@ -161,7 +162,7 @@ def __init__(
model_name: The name of the Anthropic model to use. List of model names available
[here](https://docs.anthropic.com/en/docs/about-claude/models).
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
instance of `Provider[ASYNC_ANTHROPIC_CLIENT]`. If not provided, the other parameters will be used.
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
settings: Default model settings for this model instance.
"""
Expand Down
75 changes: 58 additions & 17 deletions pydantic_ai_slim/pydantic_ai/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
from __future__ import annotations as _annotations

import os
from typing import overload
from typing import Union, overload

import httpx

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.anthropic import anthropic_model_profile
from pydantic_ai.providers import Provider

try:
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `anthropic` package to use the Anthropic provider, '
'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
) from _import_error

from typing_extensions import TypeAlias
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this import up (before the try..except).


class AnthropicProvider(Provider[AsyncAnthropic]):
ASYNC_ANTHROPIC_CLIENT: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really union what we want?

Suggested change
ASYNC_ANTHROPIC_CLIENT: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
AsyncAnthropicClient = TypeVar("AsyncAnthropicClient", AsyncAnthropic, AsyncAnthropicBedrock)

Does this work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't use full uppercase for type aliases, that's for constants. 🙏



class AnthropicProvider(Provider[ASYNC_ANTHROPIC_CLIENT]):
"""Provider for Anthropic API."""

@property
Expand All @@ -32,30 +35,52 @@ def base_url(self) -> str:
return str(self._client.base_url)

@property
def client(self) -> AsyncAnthropic:
def client(self) -> ASYNC_ANTHROPIC_CLIENT:
return self._client

def model_profile(self, model_name: str) -> ModelProfile | None:
return anthropic_model_profile(model_name)

@overload
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
def __init__(self, *, anthropic_client: ASYNC_ANTHROPIC_CLIENT | None = None) -> None: ...

@overload
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...

@overload
def __init__(
self,
*,
aws_secret_key: str | None = None,
aws_access_key: str | None = None,
aws_region: str | None = None,
aws_profile: str | None = None,
aws_session_token: str | None = None,
Comment on lines +54 to +58
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to not complicate things here.

The user can just pass the client when using Bedrock.

http_client: httpx.AsyncClient | None = None,
) -> None: ...

def __init__(
self,
*,
api_key: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
aws_secret_key: str | None = None,
aws_access_key: str | None = None,
aws_region: str | None = None,
aws_profile: str | None = None,
aws_session_token: str | None = None,
anthropic_client: ASYNC_ANTHROPIC_CLIENT | None = None,
http_client: httpx.AsyncClient | None = None,
) -> None:
"""Create a new Anthropic provider.

Args:
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
will be used if available.
aws_secret_key: AWS secret access key for Bedrock authentication.
aws_access_key: AWS access key ID for Bedrock authentication.
aws_region: AWS region for Bedrock service.
aws_profile: AWS profile name for Bedrock authentication.
aws_session_token: AWS session token for temporary credentials.
anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python)
client to use. If provided, the `api_key` and `http_client` arguments will be ignored.
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
Expand All @@ -66,14 +91,30 @@ def __init__(
self._client = anthropic_client
else:
api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
if not api_key:
raise UserError(
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
'to use the Anthropic provider.'
)

if http_client is not None:
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
if api_key is None:
if http_client is not None:
self._client = AsyncAnthropicBedrock(
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_profile=aws_profile,
aws_region=aws_region,
http_client=http_client,
)
else:
http_client = cached_async_http_client(provider='anthropic')
self._client = AsyncAnthropicBedrock(
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_profile=aws_profile,
aws_region=aws_region,
http_client=http_client,
)

else:
http_client = cached_async_http_client(provider='anthropic')
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
if http_client is not None:
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
else:
http_client = cached_async_http_client(provider='anthropic')
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
7 changes: 6 additions & 1 deletion tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from .mock_async_stream import MockAsyncStream

with try_import() as imports_successful:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncAnthropicBedrock
from anthropic.resources.beta import AsyncBeta
from anthropic.types.beta import (
BetaCodeExecutionResultBlock,
Expand Down Expand Up @@ -95,10 +95,15 @@

def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
assert isinstance(m.client, AsyncAnthropic)
assert m.client.api_key == 'foobar'
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'
assert m.base_url == 'https://api.anthropic.com'
bedrock = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider())
assert isinstance(bedrock.client, AsyncAnthropicBedrock)
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'


@dataclass
Expand Down
53 changes: 45 additions & 8 deletions tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
import httpx
import pytest

from pydantic_ai.exceptions import UserError

from ..conftest import TestEnv, try_import
from ..conftest import try_import

with try_import() as imports_successful:
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock

from pydantic_ai.providers.anthropic import AnthropicProvider

Expand All @@ -24,10 +22,17 @@ def test_anthropic_provider():
assert provider.client.api_key == 'api-key'


def test_anthropic_provider_need_api_key(env: TestEnv) -> None:
env.remove('ANTHROPIC_API_KEY')
with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'):
AnthropicProvider()
def test_anthropic_provider_with_aws_credentials() -> None:
provider = AnthropicProvider(
aws_secret_key='aws-secret-key',
aws_access_key='aws-access-key',
aws_region='us-west-2',
aws_profile='default',
aws_session_token='aws-session-token',
)
assert provider.name == 'anthropic'
assert provider.base_url == 'https://bedrock-runtime.us-west-2.amazonaws.com'
assert isinstance(provider.client, AsyncAnthropicBedrock)


def test_anthropic_provider_pass_http_client() -> None:
Expand All @@ -36,12 +41,31 @@ def test_anthropic_provider_pass_http_client() -> None:
assert isinstance(provider.client, AsyncAnthropic)
# Verify the http_client is being used by the AsyncAnthropic client
assert provider.client._client == http_client # type: ignore[reportPrivateUsage]
bedrock_provider = AnthropicProvider(
aws_secret_key='aws-secret-key',
aws_access_key='aws-access_key',
aws_region='us-west-2',
aws_profile='default',
aws_session_token='aws-session-token',
http_client=http_client,
)
assert isinstance(bedrock_provider.client, AsyncAnthropicBedrock)
assert bedrock_provider.client._client == http_client # type: ignore[reportPrivateUsage]


def test_anthropic_provider_pass_anthropic_client() -> None:
anthropic_client = AsyncAnthropic(api_key='api-key')
provider = AnthropicProvider(anthropic_client=anthropic_client)
assert provider.client == anthropic_client
bedrock_client = AsyncAnthropicBedrock(
aws_secret_key='aws-secret-key',
aws_access_key='aws-access-key',
aws_region='us-west-2',
aws_profile='default',
aws_session_token='aws-session-token',
)
provider = AnthropicProvider(anthropic_client=bedrock_client)
assert provider.client == bedrock_client


def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None:
Expand All @@ -50,3 +74,16 @@ def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -
monkeypatch.setenv('ANTHROPIC_BASE_URL', custom_base_url)
provider = AnthropicProvider(api_key='api-key')
assert provider.base_url.rstrip('/') == custom_base_url.rstrip('/')


def test_bedrock_anthropic_provider_with_envs(monkeypatch: pytest.MonkeyPatch) -> None:
# Test with environment variables for AWS credentials
monkeypatch.setenv('AWS_SECRET_ACCESS_KEY', 'aws-secret-access-key')
monkeypatch.setenv('AWS_ACCESS_KEY_ID', 'aws-access-key-id')
monkeypatch.setenv('AWS_SESSION_TOKEN', 'aws-session-token')
monkeypatch.setenv('AWS_PROFILE', 'default')
monkeypatch.setenv('AWS_REGION', 'us-west-2')
bedrock_provider = AnthropicProvider()
assert bedrock_provider.name == 'anthropic'
assert bedrock_provider.base_url == 'https://bedrock-runtime.us-west-2.amazonaws.com'
assert isinstance(bedrock_provider.client, AsyncAnthropicBedrock)
4 changes: 1 addition & 3 deletions tests/providers/test_provider_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from pydantic_ai.providers.anthropic import AnthropicProvider
from pydantic_ai.providers.azure import AzureProvider
from pydantic_ai.providers.cohere import CohereProvider
from pydantic_ai.providers.deepseek import DeepSeekProvider
from pydantic_ai.providers.fireworks import FireworksProvider
from pydantic_ai.providers.github import GitHubProvider
Expand All @@ -33,8 +32,7 @@
from pydantic_ai.providers.vercel import VercelProvider

test_infer_provider_params = [
('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'),
('cohere', CohereProvider, 'CO_API_KEY'),
('anthropic', AnthropicProvider, None),
('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'),
('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'),
('vercel', VercelProvider, 'VERCEL_AI_GATEWAY_API_KEY'),
Expand Down