Skip to content

Commit 8ddd05c

Browse files
hrahmadi71Kludex
andauthored
Add Anthropic provider classes (#1120)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 5579347 commit 8ddd05c

File tree

7 files changed

+235
-23
lines changed

7 files changed

+235
-23
lines changed

docs/models.md

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,38 @@ agent = Agent(model)
172172
...
173173
```
174174

175-
### `api_key` argument
175+
### `provider` argument
176+
177+
You can provide a custom [`Provider`][pydantic_ai.providers.Provider] via the [`provider` argument][pydantic_ai.models.anthropic.AnthropicModel.__init__]:
178+
179+
```py title="anthropic_model_provider.py"
180+
from pydantic_ai import Agent
181+
from pydantic_ai.models.anthropic import AnthropicModel
182+
from pydantic_ai.providers.anthropic import AnthropicProvider
183+
184+
model = AnthropicModel(
185+
'claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key='your-api-key')
186+
)
187+
agent = Agent(model)
188+
...
189+
```
190+
191+
### Custom HTTP Client
176192

177-
If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.anthropic.AnthropicModel.__init__]:
193+
You can customize the `AnthropicProvider` with a custom `httpx.AsyncClient`:
194+
195+
```py title="anthropic_model_custom_provider.py"
196+
from httpx import AsyncClient
178197

179-
```py title="anthropic_model_api_key.py"
180198
from pydantic_ai import Agent
181199
from pydantic_ai.models.anthropic import AnthropicModel
200+
from pydantic_ai.providers.anthropic import AnthropicProvider
182201

183-
model = AnthropicModel('claude-3-5-sonnet-latest', api_key='your-api-key')
202+
custom_http_client = AsyncClient(timeout=30)
203+
model = AnthropicModel(
204+
'claude-3-5-sonnet-latest',
205+
provider=AnthropicProvider(api_key='your-api-key', http_client=custom_http_client),
206+
)
184207
agent = Agent(model)
185208
...
186209
```

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from anthropic.types import DocumentBlockParam
1313
from httpx import AsyncClient as AsyncHTTPClient
14-
from typing_extensions import assert_never
14+
from typing_extensions import assert_never, deprecated
1515

1616
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1717
from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -31,6 +31,7 @@
3131
ToolReturnPart,
3232
UserPromptPart,
3333
)
34+
from ..providers import Provider, infer_provider
3435
from ..settings import ModelSettings
3536
from ..tools import ToolDefinition
3637
from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
@@ -111,10 +112,31 @@ class AnthropicModel(Model):
111112
_model_name: AnthropicModelName = field(repr=False)
112113
_system: str = field(default='anthropic', repr=False)
113114

115+
@overload
116+
def __init__(
117+
self,
118+
model_name: AnthropicModelName,
119+
*,
120+
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
121+
) -> None: ...
122+
123+
@deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
124+
@overload
125+
def __init__(
126+
self,
127+
model_name: AnthropicModelName,
128+
*,
129+
provider: None = None,
130+
api_key: str | None = None,
131+
anthropic_client: AsyncAnthropic | None = None,
132+
http_client: AsyncHTTPClient | None = None,
133+
) -> None: ...
134+
114135
def __init__(
115136
self,
116137
model_name: AnthropicModelName,
117138
*,
139+
provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
118140
api_key: str | None = None,
119141
anthropic_client: AsyncAnthropic | None = None,
120142
http_client: AsyncHTTPClient | None = None,
@@ -124,6 +146,8 @@ def __init__(
124146
Args:
125147
model_name: The name of the Anthropic model to use. List of model names available
126148
[here](https://docs.anthropic.com/en/docs/about-claude/models).
149+
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
150+
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
127151
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
128152
will be used if available.
129153
anthropic_client: An existing
@@ -132,7 +156,12 @@ def __init__(
132156
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
133157
"""
134158
self._model_name = model_name
135-
if anthropic_client is not None:
159+
160+
if provider is not None:
161+
if isinstance(provider, str):
162+
provider = infer_provider(provider)
163+
self.client = provider.client
164+
elif anthropic_client is not None:
136165
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
137166
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
138167
self.client = anthropic_client

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,9 @@ def infer_provider(provider: str) -> Provider[Any]:
6969
from .groq import GroqProvider
7070

7171
return GroqProvider()
72+
elif provider == 'anthropic': # pragma: no cover
73+
from .anthropic import AnthropicProvider
74+
75+
return AnthropicProvider()
7276
else: # pragma: no cover
7377
raise ValueError(f'Unknown provider: {provider}')
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import overload
5+
6+
import httpx
7+
8+
from pydantic_ai.models import cached_async_http_client
9+
10+
try:
11+
from anthropic import AsyncAnthropic
12+
except ImportError as _import_error: # pragma: no cover
13+
raise ImportError(
14+
'Please install the `anthropic` package to use the Anthropic provider, '
15+
"you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
16+
) from _import_error
17+
18+
19+
from . import Provider
20+
21+
22+
class AnthropicProvider(Provider[AsyncAnthropic]):
23+
"""Provider for Anthropic API."""
24+
25+
@property
26+
def name(self) -> str:
27+
return 'anthropic'
28+
29+
@property
30+
def base_url(self) -> str:
31+
return str(self._client.base_url)
32+
33+
@property
34+
def client(self) -> AsyncAnthropic:
35+
return self._client
36+
37+
@overload
38+
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
39+
40+
@overload
41+
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
42+
43+
def __init__(
44+
self,
45+
*,
46+
api_key: str | None = None,
47+
anthropic_client: AsyncAnthropic | None = None,
48+
http_client: httpx.AsyncClient | None = None,
49+
) -> None:
50+
"""Create a new Anthropic provider.
51+
52+
Args:
53+
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
54+
will be used if available.
55+
anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python)
56+
client to use. If provided, the `api_key` and `http_client` arguments will be ignored.
57+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
58+
"""
59+
if anthropic_client is not None:
60+
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
61+
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
62+
self._client = anthropic_client
63+
else:
64+
api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
65+
if api_key is None:
66+
raise ValueError(
67+
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
68+
'to use the Anthropic provider.'
69+
)
70+
71+
if http_client is not None:
72+
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
73+
else:
74+
self._client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

tests/models/test_anthropic.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datetime import timezone
88
from functools import cached_property
99
from typing import Any, TypeVar, Union, cast
10+
from unittest.mock import patch
1011

1112
import httpx
1213
import pytest
@@ -53,6 +54,7 @@
5354
from anthropic.types.raw_message_delta_event import Delta
5455

5556
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
57+
from pydantic_ai.providers.anthropic import AnthropicProvider
5658

5759
# note: we use Union here so that casting works with Python 3.9
5860
MockAnthropicMessage = Union[AnthropicMessage, Exception]
@@ -68,7 +70,7 @@
6870

6971

7072
def test_init():
71-
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
73+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
7274
assert m.client.api_key == 'foobar'
7375
assert m.model_name == 'claude-3-5-haiku-latest'
7476
assert m.system == 'anthropic'
@@ -81,6 +83,7 @@ class MockAnthropic:
8183
stream: Sequence[MockRawMessageStreamEvent] | Sequence[Sequence[MockRawMessageStreamEvent]] | None = None
8284
index = 0
8385
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
86+
base_url: str | None = None
8487

8588
@cached_property
8689
def messages(self) -> Any:
@@ -134,7 +137,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
134137
async def test_sync_request_text_response(allow_model_requests: None):
135138
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
136139
mock_client = MockAnthropic.create_mock(c)
137-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
140+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
138141
agent = Agent(m)
139142

140143
result = await agent.run('hello')
@@ -171,7 +174,7 @@ async def test_async_request_text_response(allow_model_requests: None):
171174
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
172175
)
173176
mock_client = MockAnthropic.create_mock(c)
174-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
177+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
175178
agent = Agent(m)
176179

177180
result = await agent.run('hello')
@@ -185,7 +188,7 @@ async def test_request_structured_response(allow_model_requests: None):
185188
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
186189
)
187190
mock_client = MockAnthropic.create_mock(c)
188-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
191+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
189192
agent = Agent(m, result_type=list[int])
190193

191194
result = await agent.run('hello')
@@ -235,7 +238,7 @@ async def test_request_tool_call(allow_model_requests: None):
235238
]
236239

237240
mock_client = MockAnthropic.create_mock(responses)
238-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
241+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
239242
agent = Agent(m, system_prompt='this is the system prompt')
240243

241244
@agent.tool_plain
@@ -327,7 +330,7 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
327330
]
328331

329332
mock_client = MockAnthropic.create_mock(responses)
330-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
333+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
331334
agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
332335

333336
@agent.tool_plain
@@ -366,7 +369,7 @@ async def retrieve_entity_info(name: str) -> str:
366369
# However, we do want to use the environment variable if present when rewriting VCR cassettes.
367370
api_key = os.environ.get('ANTHROPIC_API_KEY', 'mock-value')
368371
agent = Agent(
369-
AnthropicModel('claude-3-5-haiku-latest', api_key=api_key),
372+
AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=api_key)),
370373
system_prompt=system_prompt,
371374
tools=[retrieve_entity_info],
372375
)
@@ -436,7 +439,7 @@ async def retrieve_entity_info(name: str) -> str:
436439
async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
437440
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
438441
mock_client = MockAnthropic.create_mock(c)
439-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
442+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
440443
agent = Agent(m)
441444

442445
result = await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}))
@@ -525,7 +528,7 @@ async def test_stream_structured(allow_model_requests: None):
525528
]
526529

527530
mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
528-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
531+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
529532
agent = Agent(m)
530533

531534
tool_called = False
@@ -555,7 +558,7 @@ async def my_tool(first: str, second: str) -> int:
555558

556559
@pytest.mark.vcr()
557560
async def test_image_url_input(allow_model_requests: None, anthropic_api_key: str):
558-
m = AnthropicModel('claude-3-5-haiku-latest', api_key=anthropic_api_key)
561+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
559562
agent = Agent(m)
560563

561564
result = await agent.run(
@@ -573,7 +576,7 @@ async def test_image_url_input(allow_model_requests: None, anthropic_api_key: st
573576

574577
@pytest.mark.vcr()
575578
async def test_image_url_input_invalid_mime_type(allow_model_requests: None, anthropic_api_key: str):
576-
m = AnthropicModel('claude-3-5-haiku-latest', api_key=anthropic_api_key)
579+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
577580
agent = Agent(m)
578581

579582
result = await agent.run(
@@ -593,7 +596,7 @@ async def test_image_url_input_invalid_mime_type(allow_model_requests: None, ant
593596
async def test_audio_as_binary_content_input(allow_model_requests: None, media_type: str):
594597
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
595598
mock_client = MockAnthropic.create_mock(c)
596-
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
599+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
597600
agent = Agent(m)
598601

599602
base64_content = b'//uQZ'
@@ -610,7 +613,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
610613
body={'error': 'test error'},
611614
)
612615
)
613-
m = AnthropicModel('claude-3-5-sonnet-latest', anthropic_client=mock_client)
616+
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(anthropic_client=mock_client))
614617
agent = Agent(m)
615618
with pytest.raises(ModelHTTPError) as exc_info:
616619
agent.run_sync('hello')
@@ -623,7 +626,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
623626
async def test_document_binary_content_input(
624627
allow_model_requests: None, anthropic_api_key: str, document_content: BinaryContent
625628
):
626-
m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
629+
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
627630
agent = Agent(m)
628631

629632
result = await agent.run(['What is the main content on this document?', document_content])
@@ -634,7 +637,7 @@ async def test_document_binary_content_input(
634637

635638
@pytest.mark.vcr()
636639
async def test_document_url_input(allow_model_requests: None, anthropic_api_key: str):
637-
m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
640+
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
638641
agent = Agent(m)
639642

640643
document_url = DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')
@@ -647,7 +650,7 @@ async def test_document_url_input(allow_model_requests: None, anthropic_api_key:
647650

648651
@pytest.mark.vcr()
649652
async def test_text_document_url_input(allow_model_requests: None, anthropic_api_key: str):
650-
m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
653+
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
651654
agent = Agent(m)
652655

653656
text_document_url = DocumentUrl(url='https://example-files.online-convert.com/document/txt/example.txt')
@@ -668,3 +671,17 @@ async def test_text_document_url_input(allow_model_requests: None, anthropic_api
668671
669672
The document is formatted as a test file with metadata including its purpose, file type, and version. It also includes attribution information indicating the content is from Wikipedia and is licensed under Attribution-ShareAlike 4.0.\
670673
""")
674+
675+
676+
def test_init_with_provider():
677+
provider = AnthropicProvider(api_key='api-key')
678+
model = AnthropicModel('claude-3-opus-latest', provider=provider)
679+
assert model.model_name == 'claude-3-opus-latest'
680+
assert model.client == provider.client
681+
682+
683+
def test_init_with_provider_string():
684+
with patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env-api-key'}, clear=False):
685+
model = AnthropicModel('claude-3-opus-latest', provider='anthropic')
686+
assert model.model_name == 'claude-3-opus-latest'
687+
assert model.client is not None

0 commit comments

Comments
 (0)