
Commit 269ef33

refactor to make diff smaller
1 parent b9a125c commit 269ef33

File tree: 1 file changed, +80 -99 lines

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 80 additions & 99 deletions
@@ -143,12 +143,12 @@ def is_agent_node(
     return isinstance(node, AgentNode)
 
 
-def _is_retry_attempt(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]) -> bool:
-    # Check if we've already attempted a thinking-only retry to prevent infinite loops
-    recent_messages = (
-        ctx.state.message_history[-3:] if len(ctx.state.message_history) >= 3 else ctx.state.message_history
-    )
-    for msg in recent_messages:
+def _is_retry_attempt(message_history: list[_messages.ModelMessage]) -> bool:
+    """Check if we've already attempted a thinking-only retry to prevent infinite loops.
+
+    This is admittedly a hack, so please propose a more type-safe solution.
+    """
+    for msg in message_history[-3:]:
         if isinstance(msg, _messages.ModelRequest):
             for part in msg.parts:
                 if (
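The hunk's final context line, `if (`, opens the marker check whose body falls outside the diff context. A plausible sketch of that condition, assuming the retry is recognized by the `[THINKING_RETRY]` prefix that the new `_create_thinking_retry_request` (next hunk) embeds in a `UserPromptPart`; the helper name here is hypothetical, not part of the diff:

def _looks_like_thinking_retry(part: _messages.ModelRequestPart) -> bool:
    # Assumption: a thinking retry is identified purely by its content marker.
    return (
        isinstance(part, _messages.UserPromptPart)
        and isinstance(part.content, str)
        and part.content.startswith('[THINKING_RETRY]')
    )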
@@ -160,47 +160,25 @@ def _is_retry_attempt(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
     return False
 
 
-async def _create_thinking_retry(
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> ModelRequestNode[DepsT, NodeRunEndT]:
-    # Create retry prompt
-    retry_prompt = (
-        'Based on your thinking above, you MUST now provide '
-        'a specific answer or use the available tools to complete the task. '
-        'Do not respond with only thinking content. Provide actionable output.'
-    )
-
-    # Create the retry request using UserPromptPart for API compatibility
-    # We'll use a special content marker to detect this is a thinking retry
-    retry_part = _messages.UserPromptPart(f'[THINKING_RETRY] {retry_prompt}')
-    retry_request = _messages.ModelRequest(parts=[retry_part])
-
-    # Create new ModelRequestNode for retry (it will add the request to message history)
-    return ModelRequestNode[DepsT, NodeRunEndT](request=retry_request)
+def _create_thinking_retry_request(
+    parts: list[_messages.ModelResponsePart], message_history: list[_messages.ModelMessage]
+) -> _messages.ModelRequest | None:
+    # Handle thinking-only responses (responses that contain only ThinkingPart instances)
+    # This can happen with models that support thinking mode when they don't provide
+    # actionable output alongside their thinking content.
+    thinking_parts = [p for p in parts if isinstance(p, _messages.ThinkingPart)]
+    if thinking_parts and not _is_retry_attempt(message_history):
+        # Create retry prompt
+        retry_prompt = (
+            'Based on your thinking above, you MUST now provide '
+            'a specific answer or use the available tools to complete the task. '
+            'Do not respond with only thinking content. Provide actionable output.'
+        )
 
-
-async def _process_response_parts(
-    parts: list[_messages.ModelResponsePart], texts: list[str], tool_calls: list[_messages.ToolCallPart]
-) -> AsyncIterator[_messages.HandleResponseEvent]:
-    for part in parts:
-        if isinstance(part, _messages.TextPart):
-            # ignore empty content for text parts, see #437
-            if part.content:
-                texts.append(part.content)
-        elif isinstance(part, _messages.ToolCallPart):
-            tool_calls.append(part)
-        elif isinstance(part, _messages.BuiltinToolCallPart):
-            yield _messages.BuiltinToolCallEvent(part)
-        elif isinstance(part, _messages.BuiltinToolReturnPart):
-            yield _messages.BuiltinToolResultEvent(part)
-        elif isinstance(part, _messages.ThinkingPart):
-            # We don't need to do anything with thinking parts in this tool-calling node.
-            # We need to handle text parts in case there are no tool calls and/or the desired output comes
-            # from the text, but thinking parts should not directly influence the execution of tools or
-            # determination of the next node of graph execution here.
-            pass
-        else:
-            assert_never(part)
+        # Create the retry request using UserPromptPart for API compatibility
+        # We'll use a special content marker to detect this is a thinking retry
+        retry_part = _messages.UserPromptPart(f'[THINKING_RETRY] {retry_prompt}')
+        return _messages.ModelRequest(parts=[retry_part])
 
 
 @dataclasses.dataclass
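A rough usage sketch of the new helper (illustrative, not part of the diff). It assumes `ThinkingPart` is constructed from its `content` and that the caller appends the returned request to the message history, so a second thinking-only response is caught by `_is_retry_attempt` and no second retry is issued:

parts: list[_messages.ModelResponsePart] = [_messages.ThinkingPart(content='...')]
history: list[_messages.ModelMessage] = []

# First thinking-only response: the helper builds a marked retry request.
retry_request = _create_thinking_retry_request(parts, history)
assert retry_request is not None

# The graph appends the request to the history; on the next thinking-only
# response the marker is found in the recent messages, the helper returns
# None, and the caller falls through to the empty-response handling.
history.append(retry_request)
assert _create_thinking_retry_request(parts, history) is None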
@@ -490,67 +468,70 @@ async def stream(
             async for _event in stream:
                 pass
 
-    async def _run_stream(
+    async def _run_stream(  # noqa: C901
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         if self._events_iterator is None:
             # Ensure that the stream is only run once
-            self._events_iterator = self._create_stream_iterator(ctx)
+            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa: C901
+                texts: list[str] = []
+                tool_calls: list[_messages.ToolCallPart] = []
+                for part in self.model_response.parts:
+                    if isinstance(part, _messages.TextPart):
+                        # ignore empty content for text parts, see #437
+                        if part.content:
+                            texts.append(part.content)
+                    elif isinstance(part, _messages.ToolCallPart):
+                        tool_calls.append(part)
+                    elif isinstance(part, _messages.BuiltinToolCallPart):
+                        yield _messages.BuiltinToolCallEvent(part)
+                    elif isinstance(part, _messages.BuiltinToolReturnPart):
+                        yield _messages.BuiltinToolResultEvent(part)
+                    elif isinstance(part, _messages.ThinkingPart):
+                        # We don't need to do anything with thinking parts in this tool-calling node.
+                        # We need to handle text parts in case there are no tool calls and/or the desired output comes
+                        # from the text, but thinking parts should not directly influence the execution of tools or
+                        # determination of the next node of graph execution here.
+                        pass
+                    else:
+                        assert_never(part)
+
+                # At the moment, we prioritize at least executing tool calls if they are present.
+                # In the future, we'd consider making this configurable at the agent or run level.
+                # This accounts for cases like anthropic returns that might contain a text response
+                # and a tool call response, where the text response just indicates the tool call will happen.
+                if tool_calls:
+                    async for event in self._handle_tool_calls(ctx, tool_calls):
+                        yield event
+                elif texts:
+                    # No events are emitted during the handling of text responses, so we don't need to yield anything
+                    self._next_node = await self._handle_text_response(ctx, texts)
+                else:
+                    if retry_request := _create_thinking_retry_request(
+                        self.model_response.parts, ctx.state.message_history
+                    ):
+                        self._next_node = ModelRequestNode[DepsT, NodeRunEndT](request=retry_request)
+                        return
+
+                    # we've got an empty response, this sometimes happens with anthropic (and perhaps other models)
+                    # when the model has already returned text along side tool calls
+                    # in this scenario, if text responses are allowed, we return text from the most recent model
+                    # response, if any
+                    if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
+                        for message in reversed(ctx.state.message_history):
+                            if isinstance(message, _messages.ModelResponse):
+                                last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
+                                if last_texts:
+                                    self._next_node = await self._handle_text_response(ctx, last_texts)
+                                    return
+
+                    raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+            self._events_iterator = _run_stream()
 
         async for event in self._events_iterator:
             yield event
 
-    async def _create_stream_iterator(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> AsyncIterator[_messages.HandleResponseEvent]:
-        texts: list[str] = []
-        tool_calls: list[_messages.ToolCallPart] = []
-
-        # Process all parts in the model response
-        async for event in _process_response_parts(self.model_response.parts, texts, tool_calls):
-            yield event
-
-        # Handle the response based on what we found
-        if tool_calls:
-            async for event in self._handle_tool_calls(ctx, tool_calls):
-                yield event
-        elif texts:
-            # No events are emitted during the handling of text responses, so we don't need to yield anything
-            self._next_node = await self._handle_text_response(ctx, texts)
-        else:
-            self._next_node = await self._handle_empty_response(ctx)
-
-    async def _handle_empty_response(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
-        # Handle thinking-only responses (responses that contain only ThinkingPart instances)
-        # This can happen with models that support thinking mode when they don't provide
-        # actionable output alongside their thinking content.
-        thinking_parts = [p for p in self.model_response.parts if isinstance(p, _messages.ThinkingPart)]
-
-        if thinking_parts and not _is_retry_attempt(ctx):
-            return await _create_thinking_retry(ctx)
-
-        # Original recovery logic - this sometimes happens with anthropic (and perhaps other models)
-        # when the model has already returned text along side tool calls
-        # in this scenario, if text responses are allowed, we return text from the most recent model
-        # response, if any
-        if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
-            if next_node := await self._try_recover_from_history(ctx):
-                return next_node
-
-        raise exceptions.UnexpectedModelBehavior('Received empty model response')
-
-    async def _try_recover_from_history(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None:
-        for message in reversed(ctx.state.message_history):
-            if isinstance(message, _messages.ModelResponse):
-                last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
-                if last_texts:
-                    return await self._handle_text_response(ctx, last_texts)
-        return None
-
     async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
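The refactor folds `_create_stream_iterator`, `_handle_empty_response`, and `_try_recover_from_history` into one nested async generator. The pattern relies on the fact that calling an async generator function only creates the iterator; the body runs on first iteration, and caching the iterator on `self._events_iterator` guarantees a single pass. A minimal, self-contained sketch of that behavior (toy names, not the library's API):

import asyncio
from collections.abc import AsyncIterator


class Node:
    """Toy stand-in for the node class, caching its event iterator."""

    def __init__(self) -> None:
        self._events_iterator: AsyncIterator[str] | None = None

    async def run_stream(self) -> AsyncIterator[str]:
        if self._events_iterator is None:
            # Calling _run_stream() below does not execute the body yet;
            # async generator bodies run lazily, on first iteration.
            async def _run_stream() -> AsyncIterator[str]:
                yield 'event'

            self._events_iterator = _run_stream()

        async for event in self._events_iterator:
            yield event


async def main() -> None:
    node = Node()
    print([e async for e in node.run_stream()])  # ['event']
    print([e async for e in node.run_stream()])  # [] (iterator already consumed)


asyncio.run(main())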
