Skip to content

Commit 4918ba9

Browse files
authored
Handle streaming thinking signature deltas from Bedrock Converse API (#2785)
1 parent 187996e commit 4918ba9

File tree

3 files changed

+134
-2691
lines changed

3 files changed

+134
-2691
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import functools
44
import typing
5-
import warnings
65
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
76
from contextlib import asynccontextmanager
87
from dataclasses import dataclass, field
@@ -601,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
601600
_provider_name: str
602601
_timestamp: datetime = field(default_factory=_utils.now_utc)
603602

604-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
603+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
605604
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
606605
607606
This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -638,18 +637,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
638637
index = content_block_delta['contentBlockIndex']
639638
delta = content_block_delta['delta']
640639
if 'reasoningContent' in delta:
641-
if text := delta['reasoningContent'].get('text'):
642-
yield self._parts_manager.handle_thinking_delta(
643-
vendor_part_id=index,
644-
content=text,
645-
signature=delta['reasoningContent'].get('signature'),
646-
)
647-
else: # pragma: no cover
648-
warnings.warn(
649-
f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
650-
'Please report this to the maintainers.',
651-
UserWarning,
652-
)
640+
yield self._parts_manager.handle_thinking_delta(
641+
vendor_part_id=index,
642+
content=delta['reasoningContent'].get('text'),
643+
signature=delta['reasoningContent'].get('signature'),
644+
)
653645
if 'text' in delta:
654646
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
655647
if maybe_event is not None: # pragma: no branch

0 commit comments

Comments
 (0)