Skip to content

Commit 915135f

Browse files
authored
Send ThinkingParts back to Anthropic when used through Bedrock (#2454)
1 parent 1ad1369 commit 915135f

File tree

4 files changed

+167
-9
lines changed

4 files changed

+167
-9
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
MessageUnionTypeDef,
6060
PerformanceConfigurationTypeDef,
6161
PromptVariableValuesTypeDef,
62+
ReasoningContentBlockOutputTypeDef,
63+
ReasoningTextBlockTypeDef,
6264
SystemContentBlockTypeDef,
6365
ToolChoiceTypeDef,
6466
ToolConfigurationTypeDef,
@@ -276,9 +278,10 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
276278
if reasoning_content := item.get('reasoningContent'):
277279
reasoning_text = reasoning_content.get('reasoningText')
278280
if reasoning_text: # pragma: no branch
279-
thinking_part = ThinkingPart(content=reasoning_text['text'])
280-
if reasoning_signature := reasoning_text.get('signature'):
281-
thinking_part.signature = reasoning_signature
281+
thinking_part = ThinkingPart(
282+
content=reasoning_text['text'],
283+
signature=reasoning_text.get('signature'),
284+
)
282285
items.append(thinking_part)
283286
if text := item.get('text'):
284287
items.append(TextPart(content=text))
@@ -462,8 +465,19 @@ async def _map_messages( # noqa: C901
462465
if isinstance(item, TextPart):
463466
content.append({'text': item.content})
464467
elif isinstance(item, ThinkingPart):
465-
# NOTE: We don't pass the thinking part to Bedrock since it raises an error.
466-
pass
468+
if BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts:
469+
reasoning_text: ReasoningTextBlockTypeDef = {
470+
'text': item.content,
471+
}
472+
if item.signature:
473+
reasoning_text['signature'] = item.signature
474+
reasoning_content: ReasoningContentBlockOutputTypeDef = {
475+
'reasoningText': reasoning_text,
476+
}
477+
content.append({'reasoningContent': reasoning_content})
478+
else:
479+
# NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
480+
pass
467481
else:
468482
assert isinstance(item, ToolCallPart)
469483
content.append(self._map_tool_call(item))
@@ -610,7 +624,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
610624
delta = chunk['contentBlockDelta']['delta']
611625
if 'reasoningContent' in delta:
612626
if text := delta['reasoningContent'].get('text'):
613-
yield self._parts_manager.handle_thinking_delta(vendor_part_id=index, content=text)
627+
yield self._parts_manager.handle_thinking_delta(
628+
vendor_part_id=index,
629+
content=text,
630+
signature=delta['reasoningContent'].get('signature'),
631+
)
614632
else: # pragma: no cover
615633
warnings.warn(
616634
f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '

pydantic_ai_slim/pydantic_ai/providers/bedrock.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class BedrockModelProfile(ModelProfile):
3636

3737
bedrock_supports_tool_choice: bool = True
3838
bedrock_tool_result_format: Literal['text', 'json'] = 'text'
39+
bedrock_send_back_thinking_parts: bool = False
3940

4041

4142
class BedrockProvider(Provider[BaseClient]):
@@ -55,9 +56,9 @@ def client(self) -> BaseClient:
5556

5657
def model_profile(self, model_name: str) -> ModelProfile | None:
5758
provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
58-
'anthropic': lambda model_name: BedrockModelProfile(bedrock_supports_tool_choice=False).update(
59-
anthropic_model_profile(model_name)
60-
),
59+
'anthropic': lambda model_name: BedrockModelProfile(
60+
bedrock_supports_tool_choice=False, bedrock_send_back_thinking_parts=True
61+
).update(anthropic_model_profile(model_name)),
6162
'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
6263
mistral_model_profile(model_name)
6364
),
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
interactions:
2+
- request:
3+
body: '{"messages": [{"role": "user", "content": [{"text": "What is the largest city in the user country?"}]}], "system":
4+
[], "inferenceConfig": {}, "toolConfig": {"tools": [{"toolSpec": {"name": "get_user_country", "inputSchema": {"json":
5+
{"additionalProperties": false, "properties": {}, "type": "object"}}}}]}, "additionalModelRequestFields": {"thinking":
6+
{"type": "enabled", "budget_tokens": 1024}}}'
7+
headers:
8+
amz-sdk-invocation-id:
9+
- !!binary |
10+
MmM5YzRmZDctMmNlZS00Yzk2LWIwZWMtZjMxN2NkZDEwYmM5
11+
amz-sdk-request:
12+
- !!binary |
13+
YXR0ZW1wdD0x
14+
content-length:
15+
- '396'
16+
content-type:
17+
- !!binary |
18+
YXBwbGljYXRpb24vanNvbg==
19+
method: POST
20+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-3-7-sonnet-20250219-v1%3A0/converse
21+
response:
22+
headers:
23+
connection:
24+
- keep-alive
25+
content-length:
26+
- '1085'
27+
content-type:
28+
- application/json
29+
parsed_body:
30+
metrics:
31+
latencyMs: 3896
32+
output:
33+
message:
34+
content:
35+
- reasoningContent:
36+
reasoningText:
37+
signature: ErcBCkgIBhABGAIiQDYN+P1S3ACL3r24cAMKCrRNiuPPxvmT2uzREPLRyKEUXagRbXn97QCke6L7OEZvlh7NdA/MQNTwMZV8TuB4qPASDLNxYxDx1S3luCIfARoMIwLZXwhsvjjTN72XIjAJrEl5ryAvv6C1+6YMCPC73ffE+kgwB96IcZaOuDFQtyaoWwcFcDPBguM6YNp5e3cqHbDQ3QF5dR4PP5q+3K23pual3pUdT/0e7khyIxXkGAI=
38+
text: |-
39+
The user is asking for the largest city in their country. To answer this, I first need to determine what country the user is from. I can use the `get_user_country` function to retrieve this information.
40+
41+
Once I have the user's country, I can then provide information about the largest city in that country.
42+
- text: I'll need to check what country you're from to answer that question.
43+
- toolUse:
44+
input: {}
45+
name: get_user_country
46+
toolUseId: tooluse_W9DaUFg4Tj2cRPpndqxWSg
47+
role: assistant
48+
stopReason: tool_use
49+
usage:
50+
cacheReadInputTokenCount: 0
51+
cacheReadInputTokens: 0
52+
cacheWriteInputTokenCount: 0
53+
cacheWriteInputTokens: 0
54+
inputTokens: 397
55+
outputTokens: 130
56+
totalTokens: 527
57+
status:
58+
code: 200
59+
message: OK
60+
- request:
61+
body: '{"messages": [{"role": "user", "content": [{"text": "What is the largest city in the user country?"}]}, {"role":
62+
"assistant", "content": [{"reasoningContent": {"reasoningText": {"text": "The user is asking for the largest city in
63+
their country. To answer this, I first need to determine what country the user is from. I can use the `get_user_country`
64+
function to retrieve this information.\n\nOnce I have the user''s country, I can then provide information about the
65+
largest city in that country.", "signature": "ErcBCkgIBhABGAIiQDYN+P1S3ACL3r24cAMKCrRNiuPPxvmT2uzREPLRyKEUXagRbXn97QCke6L7OEZvlh7NdA/MQNTwMZV8TuB4qPASDLNxYxDx1S3luCIfARoMIwLZXwhsvjjTN72XIjAJrEl5ryAvv6C1+6YMCPC73ffE+kgwB96IcZaOuDFQtyaoWwcFcDPBguM6YNp5e3cqHbDQ3QF5dR4PP5q+3K23pual3pUdT/0e7khyIxXkGAI="}}},
66+
{"text": "I''ll need to check what country you''re from to answer that question."}, {"toolUse": {"toolUseId": "tooluse_W9DaUFg4Tj2cRPpndqxWSg",
67+
"name": "get_user_country", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "tooluse_W9DaUFg4Tj2cRPpndqxWSg",
68+
"content": [{"text": "Mexico"}], "status": "success"}}]}], "system": [], "inferenceConfig": {}, "toolConfig": {"tools":
69+
[{"toolSpec": {"name": "get_user_country", "inputSchema": {"json": {"additionalProperties": false, "properties": {},
70+
"type": "object"}}}}]}, "additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 1024}}}'
71+
headers:
72+
amz-sdk-invocation-id:
73+
- !!binary |
74+
ZWM5NzBkMzYtZTZhYi00MjdlLWFmMzItMTBhNTc2ZjBiMWNl
75+
amz-sdk-request:
76+
- !!binary |
77+
YXR0ZW1wdD0x
78+
content-length:
79+
- '1399'
80+
content-type:
81+
- !!binary |
82+
YXBwbGljYXRpb24vanNvbg==
83+
method: POST
84+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-3-7-sonnet-20250219-v1%3A0/converse
85+
response:
86+
headers:
87+
connection:
88+
- keep-alive
89+
content-length:
90+
- '756'
91+
content-type:
92+
- application/json
93+
parsed_body:
94+
metrics:
95+
latencyMs: 4529
96+
output:
97+
message:
98+
content:
99+
- text: |-
100+
Based on your location in Mexico, the largest city is Mexico City (Ciudad de México). It's not only the capital but also the most populous city in Mexico with a metropolitan area population of over 21 million people, making it one of the largest urban agglomerations in the world.
101+
102+
Mexico City is an important cultural, financial, and political center for the country and has a rich history dating back to the Aztec empire when it was known as Tenochtitlán.
103+
role: assistant
104+
stopReason: end_turn
105+
usage:
106+
cacheReadInputTokenCount: 0
107+
cacheReadInputTokens: 0
108+
cacheWriteInputTokenCount: 0
109+
cacheWriteInputTokens: 0
110+
inputTokens: 539
111+
outputTokens: 106
112+
totalTokens: 645
113+
status:
114+
code: 200
115+
message: OK
116+
version: 1

tests/models/test_bedrock.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,29 @@ async def test_bedrock_model_thinking_part(allow_model_requests: None, bedrock_p
654654
)
655655

656656

657+
async def test_bedrock_anthropic_tool_with_thinking(allow_model_requests: None, bedrock_provider: BedrockProvider):
658+
"""When using thinking with tool calls in Anthropic, we need to send the thinking part back to the provider.
659+
660+
This tests the issue raised in https://github.com/pydantic/pydantic-ai/issues/2453.
661+
"""
662+
m = BedrockConverseModel('us.anthropic.claude-3-7-sonnet-20250219-v1:0', provider=bedrock_provider)
663+
settings = BedrockModelSettings(
664+
bedrock_additional_model_requests_fields={'thinking': {'type': 'enabled', 'budget_tokens': 1024}},
665+
)
666+
agent = Agent(m, model_settings=settings)
667+
668+
@agent.tool_plain
669+
async def get_user_country() -> str:
670+
return 'Mexico'
671+
672+
result = await agent.run('What is the largest city in the user country?')
673+
assert result.output == snapshot("""\
674+
Based on your location in Mexico, the largest city is Mexico City (Ciudad de México). It's not only the capital but also the most populous city in Mexico with a metropolitan area population of over 21 million people, making it one of the largest urban agglomerations in the world.
675+
676+
Mexico City is an important cultural, financial, and political center for the country and has a rich history dating back to the Aztec empire when it was known as Tenochtitlán.\
677+
""")
678+
679+
657680
async def test_bedrock_group_consecutive_tool_return_parts(bedrock_provider: BedrockProvider):
658681
"""
659682
Test that consecutive ToolReturnPart objects are grouped into a single user message for Bedrock.

0 commit comments

Comments
 (0)