@@ -4,7 +4,7 @@
 from dataclasses import dataclass, field
 from datetime import timezone
 from functools import cached_property
-from typing import Any, cast
+from typing import Any, TypeVar, cast
 
 import pytest
 from inline_snapshot import snapshot
@@ -25,16 +25,27 @@
 from pydantic_ai.settings import ModelSettings
 
 from ..conftest import IsNow, try_import
+from .mock_async_stream import MockAsyncStream
 
 with try_import() as imports_successful:
     from anthropic import NOT_GIVEN, AsyncAnthropic
     from anthropic.types import (
         ContentBlock,
+        InputJSONDelta,
         Message as AnthropicMessage,
+        MessageDeltaUsage,
+        RawContentBlockDeltaEvent,
+        RawContentBlockStartEvent,
+        RawContentBlockStopEvent,
+        RawMessageDeltaEvent,
+        RawMessageStartEvent,
+        RawMessageStopEvent,
+        RawMessageStreamEvent,
         TextBlock,
         ToolUseBlock,
         Usage as AnthropicUsage,
     )
+    from anthropic.types.raw_message_delta_event import Delta
 
     from pydantic_ai.models.anthropic import AnthropicModel
 
@@ -43,6 +54,9 @@
     pytest.mark.anyio,
 ]
 
+# Type variable for generic AsyncStream
+T = TypeVar('T')
+
 
 def test_init():
     m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
@@ -53,6 +67,7 @@ def test_init():
 @dataclass
 class MockAnthropic:
     messages_: AnthropicMessage | list[AnthropicMessage] | None = None
+    stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]] | None = None
     index = 0
     chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
 
@@ -64,14 +79,31 @@ def messages(self) -> Any:
     def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
         return cast(AsyncAnthropic, cls(messages_=messages_))
 
-    async def messages_create(self, *_args: Any, **kwargs: Any) -> AnthropicMessage:
+    @classmethod
+    def create_stream_mock(
+        cls, stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]]
+    ) -> AsyncAnthropic:
+        return cast(AsyncAnthropic, cls(stream=stream))
+
+    async def messages_create(
+        self, *_args: Any, stream: bool = False, **kwargs: Any
+    ) -> AnthropicMessage | MockAsyncStream[RawMessageStreamEvent]:
         self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
 
-        assert self.messages_ is not None, '`messages` must be provided'
-        if isinstance(self.messages_, list):
-            response = self.messages_[self.index]
+        if stream:
+            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
+            # noinspection PyUnresolvedReferences
+            if isinstance(self.stream[0], list):
+                indexed_stream = cast(list[RawMessageStreamEvent], self.stream[self.index])
+                response = MockAsyncStream(iter(indexed_stream))
+            else:
+                response = MockAsyncStream(iter(cast(list[RawMessageStreamEvent], self.stream)))
         else:
-            response = self.messages_
+            assert self.messages_ is not None, '`messages` must be provided'
+            if isinstance(self.messages_, list):
+                response = self.messages_[self.index]
+            else:
+                response = self.messages_
         self.index += 1
         return response
 
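Note: `MockAsyncStream`, imported from `.mock_async_stream` above, is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it only needs to mimic anthropic's `AsyncStream` for `async for` iteration and `async with` usage (the names and layout here are illustrative, not the repo's actual module):

    from __future__ import annotations

    from collections.abc import Iterator
    from dataclasses import dataclass
    from typing import Generic, TypeVar

    T = TypeVar('T')


    @dataclass
    class MockAsyncStream(Generic[T]):
        # Wraps a plain iterator of pre-recorded events so it can be consumed asynchronously.
        _iter: Iterator[T]

        def __aiter__(self) -> MockAsyncStream[T]:
            return self

        async def __anext__(self) -> T:
            # Translate iterator exhaustion into the async-iteration stop signal.
            try:
                return next(self._iter)
            except StopIteration as e:
                raise StopAsyncIteration() from e

        async def __aenter__(self) -> MockAsyncStream[T]:
            return self

        async def __aexit__(self, *_args: object) -> None:
            return None
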
@@ -298,3 +330,112 @@ async def get_location(loc_name: str) -> str:
     assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
         not parallel_tool_calls
     )
+
+
+async def test_stream_structured(allow_model_requests: None):
+    """Test streaming structured responses with Anthropic's API.
+
+    This test simulates how Anthropic streams tool calls:
+    1. Message start
+    2. Tool block start with initial data
+    3. Tool block delta with additional data
+    4. Tool block stop
+    5. Update usage
+    6. Message stop
+    """
+    stream: list[RawMessageStreamEvent] = [
+        RawMessageStartEvent(
+            type='message_start',
+            message=AnthropicMessage(
+                id='msg_123',
+                model='claude-3-5-haiku-latest',
+                role='assistant',
+                type='message',
+                content=[],
+                stop_reason=None,
+                usage=AnthropicUsage(input_tokens=20, output_tokens=0),
+            ),
+        ),
+        # Start tool block with initial data
+        RawContentBlockStartEvent(
+            type='content_block_start',
+            index=0,
+            content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
+        ),
+        # Add more data through an incomplete JSON delta
+        RawContentBlockDeltaEvent(
+            type='content_block_delta',
+            index=0,
+            delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
+        ),
+        RawContentBlockDeltaEvent(
+            type='content_block_delta',
+            index=0,
+            delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
+        ),
+        # Mark tool block as complete
+        RawContentBlockStopEvent(type='content_block_stop', index=0),
+        # Update the top-level message with usage
+        RawMessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(
+                stop_reason='end_turn',
+            ),
+            usage=MessageDeltaUsage(
+                output_tokens=5,
+            ),
+        ),
+        # Mark message as complete
+        RawMessageStopEvent(type='message_stop'),
+    ]
+
+    done_stream: list[RawMessageStreamEvent] = [
+        RawMessageStartEvent(
+            type='message_start',
+            message=AnthropicMessage(
+                id='msg_123',
+                model='claude-3-5-haiku-latest',
+                role='assistant',
+                type='message',
+                content=[],
+                stop_reason=None,
+                usage=AnthropicUsage(input_tokens=0, output_tokens=0),
+            ),
+        ),
+        # Text block with final data
+        RawContentBlockStartEvent(
+            type='content_block_start',
+            index=0,
+            content_block=TextBlock(type='text', text='FINAL_PAYLOAD'),
+        ),
+        RawContentBlockStopEvent(type='content_block_stop', index=0),
+        RawMessageStopEvent(type='message_stop'),
+    ]
+
+    mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
+    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
+    agent = Agent(m)
+
+    tool_called = False
+
+    @agent.tool_plain
+    async def my_tool(first: str, second: str) -> int:
+        nonlocal tool_called
+        tool_called = True
+        return len(first) + len(second)
+
+    async with agent.run_stream('') as result:
+        assert not result.is_complete
+        chunks = [c async for c in result.stream(debounce_by=None)]
+
+        # The tool output doesn't echo any content to the stream, so we only get the final payload once when
+        # the block starts and once when it ends.
+        assert chunks == snapshot(
+            [
+                'FINAL_PAYLOAD',
+                'FINAL_PAYLOAD',
+            ]
+        )
+        assert result.is_complete
+        assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
+        assert tool_called
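Note: the first mocked stream above only yields valid tool arguments once the `InputJSONDelta` fragments are combined with the initial `tool_use` input. A standalone sketch of that accumulation, using the exact payloads from the test (the merge strategy shown is an assumption; the real handling lives in `pydantic_ai.models.anthropic`, which is not part of this diff):

    import json

    # Pieces as they arrive in the mocked stream: the initial input on the tool_use
    # block, then two partial_json fragments that only parse once concatenated.
    initial_input = {'first': 'One'}
    fragments = ['{"second":', '"Two"}']

    # Assumed accumulation: join the fragments, parse them, and merge into the initial input.
    args = {**initial_input, **json.loads(''.join(fragments))}
    assert args == {'first': 'One', 'second': 'Two'}  # what my_tool ultimately receives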