
Commit d8c60c1

add handling for thinking-only requests (currently causes UnexpectedModelBehavior)
1 parent f25a4e1 · commit d8c60c1

File tree: 2 files changed, +117 −52 lines

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 113 additions & 51 deletions
@@ -143,6 +143,66 @@ def is_agent_node(
     return isinstance(node, AgentNode)
 
 
+def _is_retry_attempt(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]) -> bool:
+    # Check if we've already attempted a thinking-only retry to prevent infinite loops
+    recent_messages = (
+        ctx.state.message_history[-3:] if len(ctx.state.message_history) >= 3 else ctx.state.message_history
+    )
+    for msg in recent_messages:
+        if isinstance(msg, _messages.ModelRequest):
+            for part in msg.parts:
+                if (
+                    isinstance(part, _messages.UserPromptPart)
+                    and isinstance(part.content, str)
+                    and part.content.startswith('[THINKING_RETRY]')
+                ):
+                    return True
+    return False
+
+
+async def _create_thinking_retry(
+    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+) -> ModelRequestNode[DepsT, NodeRunEndT]:
+    # Create retry prompt
+    retry_prompt = (
+        'Based on your thinking above, you MUST now provide '
+        'a specific answer or use the available tools to complete the task. '
+        'Do not respond with only thinking content. Provide actionable output.'
+    )
+
+    # Create the retry request using UserPromptPart for API compatibility
+    # We'll use a special content marker to detect this is a thinking retry
+    retry_part = _messages.UserPromptPart(f'[THINKING_RETRY] {retry_prompt}')
+    retry_request = _messages.ModelRequest(parts=[retry_part])
+
+    # Create new ModelRequestNode for retry (it will add the request to message history)
+    return ModelRequestNode[DepsT, NodeRunEndT](request=retry_request)
+
+
+async def _process_response_parts(
+    parts: list[_messages.ModelResponsePart], texts: list[str], tool_calls: list[_messages.ToolCallPart]
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    for part in parts:
+        if isinstance(part, _messages.TextPart):
+            # ignore empty content for text parts, see #437
+            if part.content:
+                texts.append(part.content)
+        elif isinstance(part, _messages.ToolCallPart):
+            tool_calls.append(part)
+        elif isinstance(part, _messages.BuiltinToolCallPart):
+            yield _messages.BuiltinToolCallEvent(part)
+        elif isinstance(part, _messages.BuiltinToolReturnPart):
+            yield _messages.BuiltinToolResultEvent(part)
+        elif isinstance(part, _messages.ThinkingPart):
+            # We don't need to do anything with thinking parts in this tool-calling node.
+            # We need to handle text parts in case there are no tool calls and/or the desired output comes
+            # from the text, but thinking parts should not directly influence the execution of tools or
+            # determination of the next node of graph execution here.
+            pass
+        else:
+            assert_never(part)
+
+
 @dataclasses.dataclass
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     """The node that handles the user prompt and instructions."""
@@ -428,65 +488,67 @@ async def stream(
             async for _event in stream:
                 pass
 
-    async def _run_stream(  # noqa: C901
+    async def _run_stream(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         if self._events_iterator is None:
             # Ensure that the stream is only run once
-
-            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
-                texts: list[str] = []
-                tool_calls: list[_messages.ToolCallPart] = []
-                for part in self.model_response.parts:
-                    if isinstance(part, _messages.TextPart):
-                        # ignore empty content for text parts, see #437
-                        if part.content:
-                            texts.append(part.content)
-                    elif isinstance(part, _messages.ToolCallPart):
-                        tool_calls.append(part)
-                    elif isinstance(part, _messages.BuiltinToolCallPart):
-                        yield _messages.BuiltinToolCallEvent(part)
-                    elif isinstance(part, _messages.BuiltinToolReturnPart):
-                        yield _messages.BuiltinToolResultEvent(part)
-                    elif isinstance(part, _messages.ThinkingPart):
-                        # We don't need to do anything with thinking parts in this tool-calling node.
-                        # We need to handle text parts in case there are no tool calls and/or the desired output comes
-                        # from the text, but thinking parts should not directly influence the execution of tools or
-                        # determination of the next node of graph execution here.
-                        pass
-                    else:
-                        assert_never(part)
-
-                # At the moment, we prioritize at least executing tool calls if they are present.
-                # In the future, we'd consider making this configurable at the agent or run level.
-                # This accounts for cases like anthropic returns that might contain a text response
-                # and a tool call response, where the text response just indicates the tool call will happen.
-                if tool_calls:
-                    async for event in self._handle_tool_calls(ctx, tool_calls):
-                        yield event
-                elif texts:
-                    # No events are emitted during the handling of text responses, so we don't need to yield anything
-                    self._next_node = await self._handle_text_response(ctx, texts)
-                else:
-                    # we've got an empty response, this sometimes happens with anthropic (and perhaps other models)
-                    # when the model has already returned text along side tool calls
-                    # in this scenario, if text responses are allowed, we return text from the most recent model
-                    # response, if any
-                    if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
-                        for message in reversed(ctx.state.message_history):
-                            if isinstance(message, _messages.ModelResponse):
-                                last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
-                                if last_texts:
-                                    self._next_node = await self._handle_text_response(ctx, last_texts)
-                                    return
-
-                    raise exceptions.UnexpectedModelBehavior('Received empty model response')
-
-            self._events_iterator = _run_stream()
+            self._events_iterator = self._create_stream_iterator(ctx)
 
         async for event in self._events_iterator:
             yield event
 
+    async def _create_stream_iterator(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        texts: list[str] = []
+        tool_calls: list[_messages.ToolCallPart] = []
+
+        # Process all parts in the model response
+        async for event in _process_response_parts(self.model_response.parts, texts, tool_calls):
+            yield event
+
+        # Handle the response based on what we found
+        if tool_calls:
+            async for event in self._handle_tool_calls(ctx, tool_calls):
+                yield event
+        elif texts:
+            # No events are emitted during the handling of text responses, so we don't need to yield anything
+            self._next_node = await self._handle_text_response(ctx, texts)
+        else:
+            self._next_node = await self._handle_empty_response(ctx)
+
+    async def _handle_empty_response(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
+        # Handle thinking-only responses (responses that contain only ThinkingPart instances).
+        # This can happen with models that support thinking mode when they don't provide
+        # actionable output alongside their thinking content.
+        thinking_parts = [p for p in self.model_response.parts if isinstance(p, _messages.ThinkingPart)]
+
+        if thinking_parts and not _is_retry_attempt(ctx):
+            return await _create_thinking_retry(ctx)
+
+        # Original recovery logic - this sometimes happens with anthropic (and perhaps other models)
+        # when the model has already returned text along side tool calls
+        # in this scenario, if text responses are allowed, we return text from the most recent model
+        # response, if any
+        if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
+            if next_node := await self._try_recover_from_history(ctx):
+                return next_node
+
+        raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+    async def _try_recover_from_history(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None:
+        for message in reversed(ctx.state.message_history):
+            if isinstance(message, _messages.ModelResponse):
+                last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
+                if last_texts:
+                    return await self._handle_text_response(ctx, last_texts)
+        return None
+
     async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
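Taken together, _handle_empty_response gives an empty response three possible outcomes: one marker-tagged retry, recovery of text from an earlier response, or the original UnexpectedModelBehavior. A rough, self-contained sketch of that decision order, in plain Python with hypothetical stand-ins rather than the real graph nodes:

from dataclasses import dataclass

# Hypothetical stand-in for _messages.ThinkingPart.
@dataclass
class ThinkingPart:
    content: str

class UnexpectedModelBehavior(Exception):
    pass

def handle_empty_response(parts: list, already_retried: bool,
                          text_output_allowed: bool, history_texts: list[str]) -> str:
    # Mirrors the order of checks in _handle_empty_response.
    thinking_parts = [p for p in parts if isinstance(p, ThinkingPart)]
    if thinking_parts and not already_retried:
        return 'retry: send a [THINKING_RETRY] follow-up request'
    if text_output_allowed and history_texts:
        return f'recovered: {history_texts[-1]!r}'
    raise UnexpectedModelBehavior('Received empty model response')

# First thinking-only response: the new path, one retry.
print(handle_empty_response([ThinkingPart('hmm...')], False, True, []))

# Thinking-only again after the retry: falls through to the pre-existing error.
try:
    handle_empty_response([ThinkingPart('still thinking')], True, False, [])
except UnexpectedModelBehavior as e:
    print(e)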

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 4 additions & 1 deletion
@@ -457,7 +457,10 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict
                     message_parts = [{'text': ''}]
                 contents.append({'role': 'user', 'parts': message_parts})
             elif isinstance(m, ModelResponse):
-                contents.append(_content_model_response(m))
+                model_content = _content_model_response(m)
+                # Skip model responses with empty parts (e.g., thinking-only responses)
+                if model_content.get('parts'):
+                    contents.append(model_content)
             else:
                 assert_never(m)
         if instructions := self._get_instructions(messages):
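The google.py guard exists because a thinking-only ModelResponse can map to a content entry with zero parts, which the Gemini API is liable to reject. A small illustration of the filter, with plain dicts standing in for the SDK's ContentDict values:

# Plain dicts standing in for google-genai ContentDict values (illustrative only).
contents = [
    {'role': 'user', 'parts': [{'text': 'What is 2 + 2?'}]},
    {'role': 'model', 'parts': []},  # a thinking-only response mapped to zero parts
    {'role': 'user', 'parts': [{'text': '[THINKING_RETRY] ...'}]},
]

# Same check as the patch: keep an entry only if it actually has parts.
filtered = [c for c in contents if c.get('parts')]
assert all(c['parts'] for c in filtered)
print(filtered)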
