Skip to content

Commit d1a7cda

Browse files
Anthropic streaming support (#684)
Co-authored-by: sydney-runkle <[email protected]>
1 parent 38e5b16 commit d1a7cda

File tree

2 files changed

+226
-26
lines changed

2 files changed

+226
-26
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
from __future__ import annotations as _annotations
22

3-
from collections.abc import AsyncIterator
3+
from collections.abc import AsyncIterable, AsyncIterator
44
from contextlib import asynccontextmanager
55
from dataclasses import dataclass, field
6+
from datetime import datetime, timezone
7+
from json import JSONDecodeError, loads as json_loads
68
from typing import Any, Literal, Union, cast, overload
79

810
from httpx import AsyncClient as AsyncHTTPClient
911
from typing_extensions import assert_never
1012

11-
from .. import usage
13+
from .. import UnexpectedModelBehavior, _utils, usage
1214
from .._utils import guard_tool_call_id as _guard_tool_call_id
1315
from ..messages import (
1416
ArgsDict,
1517
ModelMessage,
1618
ModelRequest,
1719
ModelResponse,
1820
ModelResponsePart,
21+
ModelResponseStreamEvent,
1922
RetryPromptPart,
2023
SystemPromptPart,
2124
TextPart,
@@ -38,11 +41,16 @@
3841
from anthropic.types import (
3942
Message as AnthropicMessage,
4043
MessageParam,
44+
RawContentBlockDeltaEvent,
45+
RawContentBlockStartEvent,
46+
RawContentBlockStopEvent,
4147
RawMessageDeltaEvent,
4248
RawMessageStartEvent,
49+
RawMessageStopEvent,
4350
RawMessageStreamEvent,
4451
TextBlock,
4552
TextBlockParam,
53+
TextDelta,
4654
ToolChoiceParam,
4755
ToolParam,
4856
ToolResultBlockParam,
@@ -234,24 +242,15 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
234242

235243
return ModelResponse(items, model_name=self.model_name)
236244

237-
@staticmethod
238-
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
239-
"""TODO: Process a streamed response, and prepare a streaming response to return."""
240-
# We don't yet support streamed responses from Anthropic, so we raise an error here for now.
241-
# Streamed responses will be supported in a future release.
242-
243-
raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
244-
245-
# Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
246-
# depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
247-
# RawMessageStartEvent
248-
# RawMessageDeltaEvent
249-
# RawMessageStopEvent
250-
# RawContentBlockStartEvent
251-
# RawContentBlockDeltaEvent
252-
# RawContentBlockDeltaEvent
253-
#
254-
# We might refactor streaming internally before we implement this...
245+
async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
    """Wrap a raw Anthropic event stream in an `AnthropicStreamedResponse`.

    Peeks at the stream first so an empty stream is reported as a model error
    rather than silently producing an empty response.
    """
    stream = _utils.PeekableAsyncStream(response)
    head = await stream.peek()
    if isinstance(head, _utils.Unset):
        raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

    # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
    timestamp = datetime.now(tz=timezone.utc)
    return AnthropicStreamedResponse(_model_name=self.model_name, _response=stream, _timestamp=timestamp)
255254

256255
@staticmethod
257256
def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -347,3 +346,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
347346
response_tokens=response_usage.output_tokens,
348347
total_tokens=(request_tokens or 0) + response_usage.output_tokens,
349348
)
349+
350+
351+
@dataclass
class AnthropicStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Anthropic models."""

    # The raw event stream from the Anthropic SDK (already peeked/verified non-empty).
    _response: AsyncIterable[RawMessageStreamEvent]
    # Creation time of the response; Anthropic events carry no timestamp themselves.
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate raw Anthropic stream events into `ModelResponseStreamEvent`s.

        Text deltas are forwarded immediately; tool-call argument JSON arrives in
        fragments and is buffered until it parses as a complete JSON document.
        """
        block: TextBlock | ToolUseBlock | None = None
        json_buffer: str = ''

        async for raw_event in self._response:
            self._usage += _map_usage(raw_event)

            if isinstance(raw_event, RawContentBlockStartEvent):
                block = raw_event.content_block
                if isinstance(block, TextBlock) and block.text:
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=block.text)
                elif isinstance(block, ToolUseBlock):
                    tool_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=block.id,
                        tool_name=block.name,
                        args=cast(dict[str, Any], block.input),
                        tool_call_id=block.id,
                    )
                    if tool_event is not None:
                        yield tool_event

            elif isinstance(raw_event, RawContentBlockDeltaEvent):
                delta = raw_event.delta
                if isinstance(delta, TextDelta):
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=delta.text)
                elif block and delta.type == 'input_json_delta' and isinstance(block, ToolUseBlock):
                    # Try to parse the JSON immediately, otherwise cache the value for later.
                    # This handles cases where the JSON is not currently valid but will be
                    # valid once we stream more tokens.
                    try:
                        parsed_args = json_loads(json_buffer + delta.partial_json)
                    except JSONDecodeError:
                        json_buffer += delta.partial_json
                        continue
                    json_buffer = ''

                    # For tool calls, we need to handle partial JSON updates
                    tool_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=block.id,
                        tool_name='',
                        args=parsed_args,
                        tool_call_id=block.id,
                    )
                    if tool_event is not None:
                        yield tool_event

            elif isinstance(raw_event, (RawContentBlockStopEvent, RawMessageStopEvent)):
                # Block is finished; subsequent deltas belong to a new block.
                block = None

    def timestamp(self) -> datetime:
        """Return the (locally generated) creation time of this response."""
        return self._timestamp

tests/models/test_anthropic.py

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass, field
55
from datetime import timezone
66
from functools import cached_property
7-
from typing import Any, cast
7+
from typing import Any, TypeVar, cast
88

99
import pytest
1010
from inline_snapshot import snapshot
@@ -25,16 +25,27 @@
2525
from pydantic_ai.settings import ModelSettings
2626

2727
from ..conftest import IsNow, try_import
28+
from .mock_async_stream import MockAsyncStream
2829

2930
with try_import() as imports_successful:
3031
from anthropic import NOT_GIVEN, AsyncAnthropic
3132
from anthropic.types import (
3233
ContentBlock,
34+
InputJSONDelta,
3335
Message as AnthropicMessage,
36+
MessageDeltaUsage,
37+
RawContentBlockDeltaEvent,
38+
RawContentBlockStartEvent,
39+
RawContentBlockStopEvent,
40+
RawMessageDeltaEvent,
41+
RawMessageStartEvent,
42+
RawMessageStopEvent,
43+
RawMessageStreamEvent,
3444
TextBlock,
3545
ToolUseBlock,
3646
Usage as AnthropicUsage,
3747
)
48+
from anthropic.types.raw_message_delta_event import Delta
3849

3950
from pydantic_ai.models.anthropic import AnthropicModel
4051

@@ -43,6 +54,9 @@
4354
pytest.mark.anyio,
4455
]
4556

57+
# Type variable for generic AsyncStream
58+
T = TypeVar('T')
59+
4660

4761
def test_init():
4862
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
@@ -53,6 +67,7 @@ def test_init():
5367
@dataclass
5468
class MockAnthropic:
5569
messages_: AnthropicMessage | list[AnthropicMessage] | None = None
70+
stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]] | None = None
5671
index = 0
5772
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
5873

@@ -64,14 +79,31 @@ def messages(self) -> Any:
6479
def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
6580
return cast(AsyncAnthropic, cls(messages_=messages_))
6681

67-
async def messages_create(self, *_args: Any, **kwargs: Any) -> AnthropicMessage:
82+
@classmethod
def create_stream_mock(
    cls, stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]]
) -> AsyncAnthropic:
    """Build a mock client that replays `stream` when `stream=True` is requested."""
    mock = cls(stream=stream)
    return cast(AsyncAnthropic, mock)
87+
88+
async def messages_create(
    self, *_args: Any, stream: bool = False, **kwargs: Any
) -> AnthropicMessage | MockAsyncStream[RawMessageStreamEvent]:
    """Mock of `AsyncAnthropic.messages.create`.

    Records the kwargs of every call, then returns either the next canned
    message or (when `stream=True`) a `MockAsyncStream` over the canned events.
    A list-of-lists `stream` fixture is indexed per call to simulate multiple
    sequential streaming requests.
    """
    self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

    if stream:
        assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
        if isinstance(self.stream[0], list):
            # One event list per call: pick the list for the current call index.
            events = cast(list[RawMessageStreamEvent], self.stream[self.index])
        else:
            events = cast(list[RawMessageStreamEvent], self.stream)
        response = MockAsyncStream(iter(events))
    else:
        assert self.messages_ is not None, '`messages` must be provided'
        if isinstance(self.messages_, list):
            response = self.messages_[self.index]
        else:
            response = self.messages_
    self.index += 1
    return response
77109

@@ -298,3 +330,112 @@ async def get_location(loc_name: str) -> str:
298330
assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
299331
not parallel_tool_calls
300332
)
333+
334+
335+
async def test_stream_structured(allow_model_requests: None):
    """Test streaming structured responses with Anthropic's API.

    This test simulates how Anthropic streams tool calls:
    1. Message start
    2. Tool block start with initial data
    3. Tool block delta with additional data
    4. Tool block stop
    5. Update usage
    6. Message stop
    """
    # First request: a tool call whose arguments arrive as split partial-JSON deltas.
    tool_call_stream: list[RawMessageStreamEvent] = [
        RawMessageStartEvent(
            type='message_start',
            message=AnthropicMessage(
                id='msg_123',
                model='claude-3-5-haiku-latest',
                role='assistant',
                type='message',
                content=[],
                stop_reason=None,
                usage=AnthropicUsage(input_tokens=20, output_tokens=0),
            ),
        ),
        # Start tool block with initial data
        RawContentBlockStartEvent(
            type='content_block_start',
            index=0,
            content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
        ),
        # Add more data through an incomplete JSON delta
        RawContentBlockDeltaEvent(
            type='content_block_delta',
            index=0,
            delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
        ),
        RawContentBlockDeltaEvent(
            type='content_block_delta',
            index=0,
            delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
        ),
        # Mark tool block as complete
        RawContentBlockStopEvent(type='content_block_stop', index=0),
        # Update the top-level message with usage
        RawMessageDeltaEvent(
            type='message_delta',
            delta=Delta(stop_reason='end_turn'),
            usage=MessageDeltaUsage(output_tokens=5),
        ),
        # Mark message as complete
        RawMessageStopEvent(type='message_stop'),
    ]

    # Second request: the model replies with plain text after the tool returns.
    final_text_stream: list[RawMessageStreamEvent] = [
        RawMessageStartEvent(
            type='message_start',
            message=AnthropicMessage(
                id='msg_123',
                model='claude-3-5-haiku-latest',
                role='assistant',
                type='message',
                content=[],
                stop_reason=None,
                usage=AnthropicUsage(input_tokens=0, output_tokens=0),
            ),
        ),
        # Text block with final data
        RawContentBlockStartEvent(
            type='content_block_start',
            index=0,
            content_block=TextBlock(type='text', text='FINAL_PAYLOAD'),
        ),
        RawContentBlockStopEvent(type='content_block_stop', index=0),
        RawMessageStopEvent(type='message_stop'),
    ]

    mock_client = MockAnthropic.create_stream_mock([tool_call_stream, final_text_stream])
    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
    agent = Agent(m)

    tool_called = False

    @agent.tool_plain
    async def my_tool(first: str, second: str) -> int:
        nonlocal tool_called
        tool_called = True
        return len(first) + len(second)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        chunks = [c async for c in result.stream(debounce_by=None)]

    # The tool output doesn't echo any content to the stream, so we only get the final payload once when
    # the block starts and once when it ends.
    assert chunks == snapshot(
        [
            'FINAL_PAYLOAD',
            'FINAL_PAYLOAD',
        ]
    )
    assert result.is_complete
    assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
    assert tool_called

0 commit comments

Comments
 (0)