Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ApprovalRequired,
CallDeferred,
FallbackExceptionGroup,
IncompleteToolCall,
ModelHTTPError,
ModelRetry,
UnexpectedModelBehavior,
Expand Down Expand Up @@ -124,6 +125,7 @@
'ModelRetry',
'ModelHTTPError',
'FallbackExceptionGroup',
'IncompleteToolCall',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'UserError',
Expand Down
37 changes: 31 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,28 @@ class GraphAgentState:
retries: int = 0
run_step: int = 0

def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
def increment_retries(
self,
max_result_retries: int,
error: BaseException | None = None,
model_settings: ModelSettings | None = None,
) -> None:
self.retries += 1
if self.retries > max_result_retries:
if (
self.message_history
and isinstance(model_response := self.message_history[-1], _messages.ModelResponse)
and model_response.finish_reason == 'length'
and model_response.parts
and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart)
):
try:
tool_call.args_as_dict()
except Exception:
max_tokens = (model_settings or {}).get('max_tokens') if model_settings else None
raise exceptions.IncompleteToolCall(
f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
)
message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
if error:
if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
Expand Down Expand Up @@ -568,7 +587,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
# resubmit the most recent request that resulted in an empty response,
# as the empty response and request will not create any items in the API payload,
# in the hope the model will return a non-empty response this time.
ctx.state.increment_retries(ctx.deps.max_result_retries)
ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
return

Expand Down Expand Up @@ -630,7 +649,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
)
raise ToolRetryError(m)
except ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
ctx.state.increment_retries(
ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))

self._events_iterator = _run_stream()
Expand Down Expand Up @@ -788,10 +809,14 @@ async def process_tool_calls( # noqa: C901
try:
result_data = await tool_manager.handle_call(call)
except exceptions.UnexpectedModelBehavior as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
ctx.state.increment_retries(
ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
)
raise e # pragma: lax no cover
except ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
ctx.state.increment_retries(
ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
)
yield _messages.FunctionToolCallEvent(call)
output_parts.append(e.tool_retry)
yield _messages.FunctionToolResultEvent(e.tool_retry)
Expand Down Expand Up @@ -820,7 +845,7 @@ async def process_tool_calls( # noqa: C901

# Then, we handle unknown tool calls
if tool_calls_by_kind['unknown']:
ctx.state.increment_retries(ctx.deps.max_result_retries)
ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
calls_to_run.extend(tool_calls_by_kind['unknown'])

calls_to_run_results: dict[str, DeferredToolResult] = {}
Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'ModelHTTPError',
'IncompleteToolCall',
'FallbackExceptionGroup',
)

Expand Down Expand Up @@ -168,3 +169,7 @@ class ToolRetryError(Exception):
def __init__(self, tool_retry: RetryPromptPart):
    """Wrap a `RetryPromptPart` so the agent graph can catch it and resubmit the retry prompt to the model.

    Args:
        tool_retry: The retry prompt describing why the tool call should be retried.
    """
    self.tool_retry = tool_retry
    super().__init__()


class IncompleteToolCall(UnexpectedModelBehavior):
    """Raised when the model hit its output token limit partway through a tool call, leaving the arguments truncated.

    This is a subclass of `UnexpectedModelBehavior`, so existing handlers that catch
    that exception continue to work.
    """
41 changes: 41 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DocumentUrl,
FunctionToolset,
ImageUrl,
IncompleteToolCall,
ModelMessage,
ModelMessagesTypeAdapter,
ModelProfile,
Expand Down Expand Up @@ -63,6 +64,7 @@
from pydantic_ai.output import StructuredDict, ToolOutput
from pydantic_ai.result import RunUsage
from pydantic_ai.run import AgentRunResultEvent
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied
from pydantic_ai.usage import RequestUsage

Expand Down Expand Up @@ -2448,6 +2450,45 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
)


def test_tool_exceeds_token_limit_error():
    """A tool call truncated by a `length` finish reason raises `IncompleteToolCall`."""

    def truncated_tool_call(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # Simulate a provider cutting the response off mid-arguments.
        response = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar",')])
        response.finish_reason = 'length'
        return response

    agent = Agent(FunctionModel(truncated_tool_call), output_type=str)

    # When max_tokens is set explicitly, the error message reports the configured limit.
    explicit_limit_pattern = r'Model token limit \(10\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
    with pytest.raises(IncompleteToolCall, match=explicit_limit_pattern):
        agent.run_sync('Hello', model_settings=ModelSettings(max_tokens=10))

    # Without max_tokens, the message falls back to "provider default".
    default_limit_pattern = r'Model token limit \(provider default\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
    with pytest.raises(IncompleteToolCall, match=default_limit_pattern):
        agent.run_sync('Hello')


def test_tool_exceeds_token_limit_but_complete_args():
    """A `length` finish reason with fully parsable tool args must not raise; the run completes normally."""

    def complete_tool_call_at_limit(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        if len(messages) > 1:
            # Second request: the tool already ran, so finish with plain text.
            return ModelResponse(parts=[TextPart('done')])
        # First request: a complete, parsable tool call that still reports hitting the limit.
        response = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar"}')])
        response.finish_reason = 'length'
        return response

    agent = Agent(FunctionModel(complete_tool_call_at_limit), output_type=str)

    @agent.tool_plain
    def dummy_tool(foo: str) -> str:
        return 'tool-ok'

    result = agent.run_sync('Hello')
    assert result.output == 'done'


def test_model_requests_blocked(env: TestEnv):
try:
env.set('GEMINI_API_KEY', 'foobar')
Expand Down