Skip to content

Commit fd9d593

Browse files
authored
Give model a chance to retry after producing an empty response (#2961)
1 parent 35dcc6a commit fd9d593

File tree

6 files changed

+170
-103
lines changed

6 files changed

+170
-103
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ async def _run_stream( # noqa: C901
547547
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa: C901
548548
text = ''
549549
tool_calls: list[_messages.ToolCallPart] = []
550-
thinking_parts: list[_messages.ThinkingPart] = []
550+
invisible_parts: bool = False
551551

552552
for part in self.model_response.parts:
553553
if isinstance(part, _messages.TextPart):
@@ -558,55 +558,65 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
558558
# Text parts before a built-in tool call are essentially thoughts,
559559
# not part of the final result output, so we reset the accumulated text
560560
text = ''
561+
invisible_parts = True
561562
yield _messages.BuiltinToolCallEvent(part) # pyright: ignore[reportDeprecated]
562563
elif isinstance(part, _messages.BuiltinToolReturnPart):
564+
invisible_parts = True
563565
yield _messages.BuiltinToolResultEvent(part) # pyright: ignore[reportDeprecated]
564566
elif isinstance(part, _messages.ThinkingPart):
565-
thinking_parts.append(part)
567+
invisible_parts = True
566568
else:
567569
assert_never(part)
568570

569571
# At the moment, we prioritize at least executing tool calls if they are present.
570572
# In the future, we'd consider making this configurable at the agent or run level.
571573
# This accounts for cases like anthropic returns that might contain a text response
572574
# and a tool call response, where the text response just indicates the tool call will happen.
573-
if tool_calls:
574-
async for event in self._handle_tool_calls(ctx, tool_calls):
575-
yield event
576-
elif text:
577-
# No events are emitted during the handling of text responses, so we don't need to yield anything
578-
self._next_node = await self._handle_text_response(ctx, text)
579-
elif thinking_parts:
580-
# handle thinking-only responses (responses that contain only ThinkingPart instances)
581-
# this can happen with models that support thinking mode when they don't provide
582-
# actionable output alongside their thinking content.
583-
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
584-
_messages.ModelRequest(
585-
parts=[_messages.RetryPromptPart('Responses without text or tool calls are not permitted.')]
575+
try:
576+
if tool_calls:
577+
async for event in self._handle_tool_calls(ctx, tool_calls):
578+
yield event
579+
elif text:
580+
# No events are emitted during the handling of text responses, so we don't need to yield anything
581+
self._next_node = await self._handle_text_response(ctx, text)
582+
elif invisible_parts:
583+
# handle responses with only thinking or built-in tool parts.
584+
# this can happen with models that support thinking mode when they don't provide
585+
# actionable output alongside their thinking content. so we tell the model to try again.
586+
m = _messages.RetryPromptPart(
587+
content='Responses without text or tool calls are not permitted.',
586588
)
587-
)
588-
else:
589-
# we got an empty response with no tool calls, text, or thinking
590-
# this sometimes happens with anthropic (and perhaps other models)
591-
# when the model has already returned text along side tool calls
592-
# in this scenario, if text responses are allowed, we return text from the most recent model
593-
# response, if any
594-
if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
595-
for message in reversed(ctx.state.message_history):
596-
if isinstance(message, _messages.ModelResponse):
597-
text = ''
598-
for part in message.parts:
599-
if isinstance(part, _messages.TextPart):
600-
text += part.content
601-
elif isinstance(part, _messages.BuiltinToolCallPart):
602-
# Text parts before a built-in tool call are essentially thoughts,
603-
# not part of the final result output, so we reset the accumulated text
604-
text = '' # pragma: no cover
605-
if text:
606-
self._next_node = await self._handle_text_response(ctx, text)
607-
return
608-
609-
raise exceptions.UnexpectedModelBehavior('Received empty model response')
589+
raise ToolRetryError(m)
590+
else:
591+
# we got an empty response with no tool calls, text, thinking, or built-in tool calls.
592+
# this sometimes happens with anthropic (and perhaps other models)
593+
# when the model has already returned text along side tool calls
594+
# in this scenario, if text responses are allowed, we return text from the most recent model
595+
# response, if any
596+
if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
597+
for message in reversed(ctx.state.message_history):
598+
if isinstance(message, _messages.ModelResponse):
599+
text = ''
600+
for part in message.parts:
601+
if isinstance(part, _messages.TextPart):
602+
text += part.content
603+
elif isinstance(part, _messages.BuiltinToolCallPart):
604+
# Text parts before a built-in tool call are essentially thoughts,
605+
# not part of the final result output, so we reset the accumulated text
606+
text = '' # pragma: no cover
607+
if text:
608+
self._next_node = await self._handle_text_response(ctx, text)
609+
return
610+
611+
# Go back to the model request node with an empty request, which means we'll essentially
612+
# resubmit the most recent request that resulted in an empty response,
613+
# as the empty response and request will not create any items in the API payload,
614+
# in the hope the model will return a non-empty response this time.
615+
ctx.state.increment_retries(ctx.deps.max_result_retries)
616+
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
617+
except ToolRetryError as e:
618+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
619+
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
610620

611621
self._events_iterator = _run_stream()
612622

@@ -666,23 +676,19 @@ async def _handle_text_response(
666676
text: str,
667677
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
668678
output_schema = ctx.deps.output_schema
669-
try:
670-
run_context = build_run_context(ctx)
671-
if isinstance(output_schema, _output.TextOutputSchema):
672-
result_data = await output_schema.process(text, run_context)
673-
else:
674-
m = _messages.RetryPromptPart(
675-
content='Plain text responses are not permitted, please include your response in a tool call',
676-
)
677-
raise ToolRetryError(m)
679+
run_context = build_run_context(ctx)
678680

679-
for validator in ctx.deps.output_validators:
680-
result_data = await validator.validate(result_data, run_context)
681-
except ToolRetryError as e:
682-
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
683-
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
681+
if isinstance(output_schema, _output.TextOutputSchema):
682+
result_data = await output_schema.process(text, run_context)
684683
else:
685-
return self._handle_final_result(ctx, result.FinalResult(result_data), [])
684+
m = _messages.RetryPromptPart(
685+
content='Plain text responses are not permitted, please include your response in a tool call',
686+
)
687+
raise ToolRetryError(m)
688+
689+
for validator in ctx.deps.output_validators:
690+
result_data = await validator.validate(result_data, run_context)
691+
return self._handle_final_result(ctx, result.FinalResult(result_data), [])
686692

687693
__repr__ = dataclasses_no_defaults_repr
688694

tests/models/test_groq.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pydantic import BaseModel
1616
from typing_extensions import TypedDict
1717

18-
from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
18+
from pydantic_ai import Agent, ModelHTTPError, ModelRetry
1919
from pydantic_ai.builtin_tools import WebSearchTool
2020
from pydantic_ai.messages import (
2121
BinaryContent,
@@ -533,17 +533,6 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
533533
assert result.is_complete
534534

535535

536-
async def test_no_content(allow_model_requests: None):
537-
stream = chunk([ChoiceDelta()]), chunk([ChoiceDelta()])
538-
mock_client = MockGroq.create_mock_stream(stream)
539-
m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
540-
agent = Agent(m, output_type=MyTypedDict)
541-
542-
with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
543-
async with agent.run_stream(''):
544-
pass
545-
546-
547536
async def test_no_delta(allow_model_requests: None):
548537
stream = chunk([]), text_chunk('hello '), text_chunk('world')
549538
mock_client = MockGroq.create_mock_stream(stream)

tests/models/test_huggingface.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from inline_snapshot import snapshot
1414
from typing_extensions import TypedDict
1515

16-
from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior
16+
from pydantic_ai import Agent, ModelRetry
1717
from pydantic_ai.exceptions import ModelHTTPError
1818
from pydantic_ai.messages import (
1919
AudioUrl,
@@ -601,20 +601,6 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
601601
assert result.is_complete
602602

603603

604-
async def test_no_content(allow_model_requests: None):
605-
stream = [
606-
chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore
607-
chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore
608-
]
609-
mock_client = MockHuggingFace.create_stream_mock(stream)
610-
m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'))
611-
agent = Agent(m, output_type=MyTypedDict)
612-
613-
with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
614-
async with agent.run_stream(''):
615-
pass
616-
617-
618604
async def test_no_delta(allow_model_requests: None):
619605
stream = [
620606
chunk([]),

tests/models/test_openai.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -606,17 +606,6 @@ async def test_stream_text_empty_think_tag_and_text_before_tool_call(allow_model
606606
assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'})
607607

608608

609-
async def test_no_content(allow_model_requests: None):
610-
stream = [chunk([ChoiceDelta()]), chunk([ChoiceDelta()])]
611-
mock_client = MockOpenAI.create_mock_stream(stream)
612-
m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
613-
agent = Agent(m, output_type=MyTypedDict)
614-
615-
with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
616-
async with agent.run_stream(''):
617-
pass
618-
619-
620609
async def test_no_delta(allow_model_requests: None):
621610
stream = [
622611
chunk([]),

tests/test_agent.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,14 +2168,79 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes
21682168
assert result.new_messages() == []
21692169

21702170

2171-
def test_empty_tool_calls():
2172-
def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
2171+
def test_empty_response():
2172+
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
2173+
if len(messages) == 1:
2174+
return ModelResponse(parts=[])
2175+
else:
2176+
return ModelResponse(parts=[TextPart('ok here is text')])
2177+
2178+
agent = Agent(FunctionModel(llm))
2179+
2180+
result = agent.run_sync('Hello')
2181+
2182+
assert result.all_messages() == snapshot(
2183+
[
2184+
ModelRequest(
2185+
parts=[
2186+
UserPromptPart(
2187+
content='Hello',
2188+
timestamp=IsDatetime(),
2189+
)
2190+
]
2191+
),
2192+
ModelResponse(
2193+
parts=[],
2194+
usage=RequestUsage(input_tokens=51),
2195+
model_name='function:llm:',
2196+
timestamp=IsDatetime(),
2197+
),
2198+
ModelRequest(parts=[]),
2199+
ModelResponse(
2200+
parts=[TextPart(content='ok here is text')],
2201+
usage=RequestUsage(input_tokens=51, output_tokens=4),
2202+
model_name='function:llm:',
2203+
timestamp=IsDatetime(),
2204+
),
2205+
]
2206+
)
2207+
2208+
2209+
def test_empty_response_without_recovery():
2210+
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
21732211
return ModelResponse(parts=[])
21742212

2175-
agent = Agent(FunctionModel(empty))
2213+
agent = Agent(FunctionModel(llm), output_type=tuple[str, int])
21762214

2177-
with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
2178-
agent.run_sync('Hello')
2215+
with capture_run_messages() as messages:
2216+
with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'):
2217+
agent.run_sync('Hello')
2218+
2219+
assert messages == snapshot(
2220+
[
2221+
ModelRequest(
2222+
parts=[
2223+
UserPromptPart(
2224+
content='Hello',
2225+
timestamp=IsDatetime(),
2226+
)
2227+
]
2228+
),
2229+
ModelResponse(
2230+
parts=[],
2231+
usage=RequestUsage(input_tokens=51),
2232+
model_name='function:llm:',
2233+
timestamp=IsDatetime(),
2234+
),
2235+
ModelRequest(parts=[]),
2236+
ModelResponse(
2237+
parts=[],
2238+
usage=RequestUsage(input_tokens=51),
2239+
model_name='function:llm:',
2240+
timestamp=IsDatetime(),
2241+
),
2242+
]
2243+
)
21792244

21802245

21812246
def test_unknown_tool():

tests/test_streaming.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -400,15 +400,47 @@ async def ret_a(x: str) -> str:
400400
)
401401

402402

403-
async def test_call_tool_empty():
404-
async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
405-
yield {}
403+
async def test_empty_response():
404+
async def stream_structured_function(
405+
messages: list[ModelMessage], _: AgentInfo
406+
) -> AsyncIterator[DeltaToolCalls | str]:
407+
if len(messages) == 1:
408+
yield {}
409+
else:
410+
yield 'ok here is text'
406411

407-
agent = Agent(FunctionModel(stream_function=stream_structured_function), output_type=tuple[str, int])
412+
agent = Agent(FunctionModel(stream_function=stream_structured_function))
408413

409-
with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
410-
async with agent.run_stream('hello'):
411-
pass
414+
async with agent.run_stream('hello') as result:
415+
response = await result.get_output()
416+
assert response == snapshot('ok here is text')
417+
messages = result.all_messages()
418+
419+
assert messages == snapshot(
420+
[
421+
ModelRequest(
422+
parts=[
423+
UserPromptPart(
424+
content='hello',
425+
timestamp=IsDatetime(),
426+
)
427+
]
428+
),
429+
ModelResponse(
430+
parts=[],
431+
usage=RequestUsage(input_tokens=50),
432+
model_name='function::stream_structured_function',
433+
timestamp=IsDatetime(),
434+
),
435+
ModelRequest(parts=[]),
436+
ModelResponse(
437+
parts=[TextPart(content='ok here is text')],
438+
usage=RequestUsage(input_tokens=50, output_tokens=4),
439+
model_name='function::stream_structured_function',
440+
timestamp=IsDatetime(),
441+
),
442+
]
443+
)
412444

413445

414446
async def test_call_tool_wrong_name():

0 commit comments

Comments
 (0)