7
7
from datetime import timezone
8
8
from functools import cached_property
9
9
from typing import Any , TypeVar , Union , cast
10
+ from unittest .mock import patch
10
11
11
12
import httpx
12
13
import pytest
53
54
from anthropic .types .raw_message_delta_event import Delta
54
55
55
56
from pydantic_ai .models .anthropic import AnthropicModel , AnthropicModelSettings
57
+ from pydantic_ai .providers .anthropic import AnthropicProvider
56
58
57
59
# note: we use Union here so that casting works with Python 3.9
58
60
MockAnthropicMessage = Union [AnthropicMessage , Exception ]
68
70
69
71
70
72
def test_init ():
71
- m = AnthropicModel ('claude-3-5-haiku-latest' , api_key = 'foobar' )
73
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( api_key = 'foobar' ) )
72
74
assert m .client .api_key == 'foobar'
73
75
assert m .model_name == 'claude-3-5-haiku-latest'
74
76
assert m .system == 'anthropic'
@@ -81,6 +83,7 @@ class MockAnthropic:
81
83
stream : Sequence [MockRawMessageStreamEvent ] | Sequence [Sequence [MockRawMessageStreamEvent ]] | None = None
82
84
index = 0
83
85
chat_completion_kwargs : list [dict [str , Any ]] = field (default_factory = list )
86
+ base_url : str | None = None
84
87
85
88
@cached_property
86
89
def messages (self ) -> Any :
@@ -134,7 +137,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
134
137
async def test_sync_request_text_response (allow_model_requests : None ):
135
138
c = completion_message ([TextBlock (text = 'world' , type = 'text' )], AnthropicUsage (input_tokens = 5 , output_tokens = 10 ))
136
139
mock_client = MockAnthropic .create_mock (c )
137
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
140
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
138
141
agent = Agent (m )
139
142
140
143
result = await agent .run ('hello' )
@@ -171,7 +174,7 @@ async def test_async_request_text_response(allow_model_requests: None):
171
174
usage = AnthropicUsage (input_tokens = 3 , output_tokens = 5 ),
172
175
)
173
176
mock_client = MockAnthropic .create_mock (c )
174
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
177
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
175
178
agent = Agent (m )
176
179
177
180
result = await agent .run ('hello' )
@@ -185,7 +188,7 @@ async def test_request_structured_response(allow_model_requests: None):
185
188
usage = AnthropicUsage (input_tokens = 3 , output_tokens = 5 ),
186
189
)
187
190
mock_client = MockAnthropic .create_mock (c )
188
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
191
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
189
192
agent = Agent (m , result_type = list [int ])
190
193
191
194
result = await agent .run ('hello' )
@@ -235,7 +238,7 @@ async def test_request_tool_call(allow_model_requests: None):
235
238
]
236
239
237
240
mock_client = MockAnthropic .create_mock (responses )
238
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
241
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
239
242
agent = Agent (m , system_prompt = 'this is the system prompt' )
240
243
241
244
@agent .tool_plain
@@ -327,7 +330,7 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
327
330
]
328
331
329
332
mock_client = MockAnthropic .create_mock (responses )
330
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
333
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
331
334
agent = Agent (m , model_settings = ModelSettings (parallel_tool_calls = parallel_tool_calls ))
332
335
333
336
@agent .tool_plain
@@ -366,7 +369,7 @@ async def retrieve_entity_info(name: str) -> str:
366
369
# However, we do want to use the environment variable if present when rewriting VCR cassettes.
367
370
api_key = os .environ .get ('ANTHROPIC_API_KEY' , 'mock-value' )
368
371
agent = Agent (
369
- AnthropicModel ('claude-3-5-haiku-latest' , api_key = api_key ),
372
+ AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( api_key = api_key ) ),
370
373
system_prompt = system_prompt ,
371
374
tools = [retrieve_entity_info ],
372
375
)
@@ -436,7 +439,7 @@ async def retrieve_entity_info(name: str) -> str:
436
439
async def test_anthropic_specific_metadata (allow_model_requests : None ) -> None :
437
440
c = completion_message ([TextBlock (text = 'world' , type = 'text' )], AnthropicUsage (input_tokens = 5 , output_tokens = 10 ))
438
441
mock_client = MockAnthropic .create_mock (c )
439
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
442
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
440
443
agent = Agent (m )
441
444
442
445
result = await agent .run ('hello' , model_settings = AnthropicModelSettings (anthropic_metadata = {'user_id' : '123' }))
@@ -525,7 +528,7 @@ async def test_stream_structured(allow_model_requests: None):
525
528
]
526
529
527
530
mock_client = MockAnthropic .create_stream_mock ([stream , done_stream ])
528
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
531
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
529
532
agent = Agent (m )
530
533
531
534
tool_called = False
@@ -555,7 +558,7 @@ async def my_tool(first: str, second: str) -> int:
555
558
556
559
@pytest .mark .vcr ()
557
560
async def test_image_url_input (allow_model_requests : None , anthropic_api_key : str ):
558
- m = AnthropicModel ('claude-3-5-haiku-latest' , api_key = anthropic_api_key )
561
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( api_key = anthropic_api_key ) )
559
562
agent = Agent (m )
560
563
561
564
result = await agent .run (
@@ -573,7 +576,7 @@ async def test_image_url_input(allow_model_requests: None, anthropic_api_key: st
573
576
574
577
@pytest .mark .vcr ()
575
578
async def test_image_url_input_invalid_mime_type (allow_model_requests : None , anthropic_api_key : str ):
576
- m = AnthropicModel ('claude-3-5-haiku-latest' , api_key = anthropic_api_key )
579
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( api_key = anthropic_api_key ) )
577
580
agent = Agent (m )
578
581
579
582
result = await agent .run (
@@ -593,7 +596,7 @@ async def test_image_url_input_invalid_mime_type(allow_model_requests: None, ant
593
596
async def test_audio_as_binary_content_input (allow_model_requests : None , media_type : str ):
594
597
c = completion_message ([TextBlock (text = 'world' , type = 'text' )], AnthropicUsage (input_tokens = 5 , output_tokens = 10 ))
595
598
mock_client = MockAnthropic .create_mock (c )
596
- m = AnthropicModel ('claude-3-5-haiku-latest' , anthropic_client = mock_client )
599
+ m = AnthropicModel ('claude-3-5-haiku-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
597
600
agent = Agent (m )
598
601
599
602
base64_content = b'//uQZ'
@@ -610,7 +613,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
610
613
body = {'error' : 'test error' },
611
614
)
612
615
)
613
- m = AnthropicModel ('claude-3-5-sonnet-latest' , anthropic_client = mock_client )
616
+ m = AnthropicModel ('claude-3-5-sonnet-latest' , provider = AnthropicProvider ( anthropic_client = mock_client ) )
614
617
agent = Agent (m )
615
618
with pytest .raises (ModelHTTPError ) as exc_info :
616
619
agent .run_sync ('hello' )
@@ -623,7 +626,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
623
626
async def test_document_binary_content_input (
624
627
allow_model_requests : None , anthropic_api_key : str , document_content : BinaryContent
625
628
):
626
- m = AnthropicModel ('claude-3-5-sonnet-latest' , api_key = anthropic_api_key )
629
+ m = AnthropicModel ('claude-3-5-sonnet-latest' , provider = AnthropicProvider ( api_key = anthropic_api_key ) )
627
630
agent = Agent (m )
628
631
629
632
result = await agent .run (['What is the main content on this document?' , document_content ])
@@ -634,7 +637,7 @@ async def test_document_binary_content_input(
634
637
635
638
@pytest .mark .vcr ()
636
639
async def test_document_url_input (allow_model_requests : None , anthropic_api_key : str ):
637
- m = AnthropicModel ('claude-3-5-sonnet-latest' , api_key = anthropic_api_key )
640
+ m = AnthropicModel ('claude-3-5-sonnet-latest' , provider = AnthropicProvider ( api_key = anthropic_api_key ) )
638
641
agent = Agent (m )
639
642
640
643
document_url = DocumentUrl (url = 'https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf' )
@@ -647,7 +650,7 @@ async def test_document_url_input(allow_model_requests: None, anthropic_api_key:
647
650
648
651
@pytest .mark .vcr ()
649
652
async def test_text_document_url_input (allow_model_requests : None , anthropic_api_key : str ):
650
- m = AnthropicModel ('claude-3-5-sonnet-latest' , api_key = anthropic_api_key )
653
+ m = AnthropicModel ('claude-3-5-sonnet-latest' , provider = AnthropicProvider ( api_key = anthropic_api_key ) )
651
654
agent = Agent (m )
652
655
653
656
text_document_url = DocumentUrl (url = 'https://example-files.online-convert.com/document/txt/example.txt' )
@@ -668,3 +671,17 @@ async def test_text_document_url_input(allow_model_requests: None, anthropic_api
668
671
669
672
The document is formatted as a test file with metadata including its purpose, file type, and version. It also includes attribution information indicating the content is from Wikipedia and is licensed under Attribution-ShareAlike 4.0.\
670
673
""" )
674
+
675
+
676
+ def test_init_with_provider ():
677
+ provider = AnthropicProvider (api_key = 'api-key' )
678
+ model = AnthropicModel ('claude-3-opus-latest' , provider = provider )
679
+ assert model .model_name == 'claude-3-opus-latest'
680
+ assert model .client == provider .client
681
+
682
+
683
+ def test_init_with_provider_string ():
684
+ with patch .dict (os .environ , {'ANTHROPIC_API_KEY' : 'env-api-key' }, clear = False ):
685
+ model = AnthropicModel ('claude-3-opus-latest' , provider = 'anthropic' )
686
+ assert model .model_name == 'claude-3-opus-latest'
687
+ assert model .client is not None
0 commit comments