Skip to content

Commit d3d3ded

Browse files
akoshelKludex
andauthored
Allow proper type on AnthropicProvider when using Bedrock (#2490)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 8828587 commit d3d3ded

File tree

4 files changed

+29
-32
lines changed

4 files changed

+29
-32
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@
3737
)
3838
from ..profiles import ModelProfileSpec
3939
from ..providers import Provider, infer_provider
40+
from ..providers.anthropic import AsyncAnthropicClient
4041
from ..settings import ModelSettings
4142
from ..tools import ToolDefinition
4243
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
4344

4445
try:
45-
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
46+
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
4647
from anthropic.types.beta import (
4748
BetaBase64PDFBlockParam,
4849
BetaBase64PDFSourceParam,
@@ -134,16 +135,16 @@ class AnthropicModel(Model):
134135
Apart from `__init__`, all methods are private or match those of the base class.
135136
"""
136137

137-
client: AsyncAnthropic = field(repr=False)
138+
client: AsyncAnthropicClient = field(repr=False)
138139

139140
_model_name: AnthropicModelName = field(repr=False)
140-
_provider: Provider[AsyncAnthropic] = field(repr=False)
141+
_provider: Provider[AsyncAnthropicClient] = field(repr=False)
141142

142143
def __init__(
143144
self,
144145
model_name: AnthropicModelName,
145146
*,
146-
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
147+
provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
147148
profile: ModelProfileSpec | None = None,
148149
settings: ModelSettings | None = None,
149150
):
@@ -153,7 +154,7 @@ def __init__(
153154
model_name: The name of the Anthropic model to use. List of model names available
154155
[here](https://docs.anthropic.com/en/docs/about-claude/models).
155156
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
156-
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
157+
instance of `Provider[AsyncAnthropicClient]`. If not provided, the other parameters will be used.
157158
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
158159
settings: Default model settings for this model instance.
159160
"""

pydantic_ai_slim/pydantic_ai/providers/anthropic.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations as _annotations
22

33
import os
4-
from typing import overload
4+
from typing import Union, overload
55

66
import httpx
7+
from typing_extensions import TypeAlias
78

89
from pydantic_ai.exceptions import UserError
910
from pydantic_ai.models import cached_async_http_client
@@ -12,15 +13,18 @@
1213
from pydantic_ai.providers import Provider
1314

1415
try:
15-
from anthropic import AsyncAnthropic
16-
except ImportError as _import_error: # pragma: no cover
16+
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
17+
except ImportError as _import_error:
1718
raise ImportError(
1819
'Please install the `anthropic` package to use the Anthropic provider, '
1920
'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
2021
) from _import_error
2122

2223

23-
class AnthropicProvider(Provider[AsyncAnthropic]):
24+
AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
25+
26+
27+
class AnthropicProvider(Provider[AsyncAnthropicClient]):
2428
"""Provider for Anthropic API."""
2529

2630
@property
@@ -32,14 +36,14 @@ def base_url(self) -> str:
3236
return str(self._client.base_url)
3337

3438
@property
35-
def client(self) -> AsyncAnthropic:
39+
def client(self) -> AsyncAnthropicClient:
3640
return self._client
3741

3842
def model_profile(self, model_name: str) -> ModelProfile | None:
3943
return anthropic_model_profile(model_name)
4044

4145
@overload
42-
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
46+
def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...
4347

4448
@overload
4549
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
@@ -48,7 +52,7 @@ def __init__(
4852
self,
4953
*,
5054
api_key: str | None = None,
51-
anthropic_client: AsyncAnthropic | None = None,
55+
anthropic_client: AsyncAnthropicClient | None = None,
5256
http_client: httpx.AsyncClient | None = None,
5357
) -> None:
5458
"""Create a new Anthropic provider.
@@ -71,7 +75,6 @@ def __init__(
7175
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
7276
'to use the Anthropic provider.'
7377
)
74-
7578
if http_client is not None:
7679
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
7780
else:

tests/models/test_anthropic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797

9898
def test_init():
9999
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
100+
assert isinstance(m.client, AsyncAnthropic)
100101
assert m.client.api_key == 'foobar'
101102
assert m.model_name == 'claude-3-5-haiku-latest'
102103
assert m.system == 'anthropic'

tests/providers/test_anthropic.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
from __future__ import annotations as _annotations
22

3-
import httpx
43
import pytest
54

6-
from pydantic_ai.exceptions import UserError
7-
8-
from ..conftest import TestEnv, try_import
5+
from ..conftest import try_import
96

107
with try_import() as imports_successful:
11-
from anthropic import AsyncAnthropic
8+
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
129

1310
from pydantic_ai.providers.anthropic import AnthropicProvider
1411

@@ -24,24 +21,19 @@ def test_anthropic_provider():
2421
assert provider.client.api_key == 'api-key'
2522

2623

27-
def test_anthropic_provider_need_api_key(env: TestEnv) -> None:
28-
env.remove('ANTHROPIC_API_KEY')
29-
with pytest.raises(UserError, match=r'.*ANTHROPIC_API_KEY.*'):
30-
AnthropicProvider()
31-
32-
33-
def test_anthropic_provider_pass_http_client() -> None:
34-
http_client = httpx.AsyncClient()
35-
provider = AnthropicProvider(http_client=http_client, api_key='api-key')
36-
assert isinstance(provider.client, AsyncAnthropic)
37-
# Verify the http_client is being used by the AsyncAnthropic client
38-
assert provider.client._client == http_client # type: ignore[reportPrivateUsage]
39-
40-
4124
def test_anthropic_provider_pass_anthropic_client() -> None:
4225
anthropic_client = AsyncAnthropic(api_key='api-key')
4326
provider = AnthropicProvider(anthropic_client=anthropic_client)
4427
assert provider.client == anthropic_client
28+
bedrock_client = AsyncAnthropicBedrock(
29+
aws_secret_key='aws-secret-key',
30+
aws_access_key='aws-access-key',
31+
aws_region='us-west-2',
32+
aws_profile='default',
33+
aws_session_token='aws-session-token',
34+
)
35+
provider = AnthropicProvider(anthropic_client=bedrock_client)
36+
assert provider.client == bedrock_client
4537

4638

4739
def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None:

0 commit comments

Comments
 (0)