2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/__init__.py
@@ -24,6 +24,7 @@
FallbackExceptionGroup,
ModelHTTPError,
ModelRetry,
ToolExceedsTokenLimitError,
UnexpectedModelBehavior,
UsageLimitExceeded,
UserError,
@@ -124,6 +125,7 @@
'ModelRetry',
'ModelHTTPError',
'FallbackExceptionGroup',
'ToolExceedsTokenLimitError',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'UserError',
32 changes: 26 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -92,9 +92,24 @@ class GraphAgentState:
retries: int = 0
run_step: int = 0

def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
def increment_retries(
self, max_result_retries: int, error: BaseException | None = None, max_tokens: int | None = None
) -> None:
self.retries += 1
if self.retries > max_result_retries:
if (
self.message_history
and isinstance(model_response := self.message_history[-1], _messages.ModelResponse)
and model_response.finish_reason == 'length'
and model_response.parts
and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart)
):
try:
tool_call.args_as_dict()
except Exception:
raise exceptions.ToolExceedsTokenLimitError(
f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
)

[Collaborator review comment]: I think this should be a subclass of UnexpectedModelBehavior for backward compatibility, and I suggest renaming it IncompleteToolCall.

message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
if error:
if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
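
A minimal standalone sketch of the detection idea in this hunk (illustration only, not part of the diff): a response cut off by the token limit ends with a ToolCallPart whose JSON args string is incomplete, so parsing it fails. Stdlib json stands in here for the real tool_call.args_as_dict() call.

```python
# Illustration only: why a truncated tool call is detectable.
import json

def args_are_incomplete(args: str) -> bool:
    """Mirrors the try/except around tool_call.args_as_dict() above."""
    try:
        json.loads(args)
        return False
    except json.JSONDecodeError:
        return True

assert args_are_incomplete('{"foo": "bar",')      # cut off mid-arguments
assert not args_are_incomplete('{"foo": "bar"}')  # complete arguments
```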
@@ -568,7 +583,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
# resubmit the most recent request that resulted in an empty response,
# as the empty response and request will not create any items in the API payload,
# in the hope the model will return a non-empty response this time.
ctx.state.increment_retries(ctx.deps.max_result_retries)
max_tokens = (ctx.deps.model_settings or {}).get('max_tokens') if ctx.deps.model_settings else None

[Collaborator review comment]: Can we pass in the entire model_settings object to clean these repeated sections up a bit?

ctx.state.increment_retries(ctx.deps.max_result_retries, max_tokens=max_tokens)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
return
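
A hedged sketch of the refactor the review comment above suggests (hypothetical, not what this PR merged): have increment_retries accept the whole ModelSettings object, so the max_tokens extraction lives in one place instead of being repeated at every call site.

```python
# Hypothetical refactor per the review comment above; not the merged code.
from __future__ import annotations

from pydantic_ai.settings import ModelSettings


class GraphAgentState:  # trimmed to the relevant method
    retries: int = 0

    def increment_retries(
        self,
        max_result_retries: int,
        error: BaseException | None = None,
        model_settings: ModelSettings | None = None,
    ) -> None:
        self.retries += 1
        max_tokens = (model_settings or {}).get('max_tokens')
        ...  # same incomplete-tool-call check as in the hunk above

# Call sites would then collapse to a single form:
#   ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
```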

@@ -630,7 +646,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
)
raise ToolRetryError(m)
except ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
max_tokens = (ctx.deps.model_settings or {}).get('max_tokens') if ctx.deps.model_settings else None
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e, max_tokens=max_tokens)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))

self._events_iterator = _run_stream()
@@ -776,10 +793,12 @@ async def process_tool_calls(  # noqa: C901
try:
result_data = await tool_manager.handle_call(call)
except exceptions.UnexpectedModelBehavior as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
max_tokens = (ctx.deps.model_settings or {}).get('max_tokens') if ctx.deps.model_settings else None
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e, max_tokens=max_tokens)
raise e # pragma: lax no cover
except ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
max_tokens = (ctx.deps.model_settings or {}).get('max_tokens') if ctx.deps.model_settings else None
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e, max_tokens=max_tokens)
yield _messages.FunctionToolCallEvent(call)
output_parts.append(e.tool_retry)
yield _messages.FunctionToolResultEvent(e.tool_retry)
@@ -808,7 +827,8 @@

# Then, we handle unknown tool calls
if tool_calls_by_kind['unknown']:
ctx.state.increment_retries(ctx.deps.max_result_retries)
max_tokens = (ctx.deps.model_settings or {}).get('max_tokens') if ctx.deps.model_settings else None
ctx.state.increment_retries(ctx.deps.max_result_retries, max_tokens=max_tokens)
calls_to_run.extend(tool_calls_by_kind['unknown'])

calls_to_run_results: dict[str, DeferredToolResult] = {}
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -23,6 +23,7 @@
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'ModelHTTPError',
'ToolExceedsTokenLimitError',
'FallbackExceptionGroup',
)

@@ -168,3 +169,7 @@ class ToolRetryError(Exception):
def __init__(self, tool_retry: RetryPromptPart):
self.tool_retry = tool_retry
super().__init__()


class ToolExceedsTokenLimitError(AgentRunError):
"""Error raised when a model stops due to token limit while emitting a tool call."""
41 changes: 41 additions & 0 deletions tests/test_agent.py
@@ -40,6 +40,7 @@
SystemPromptPart,
TextPart,
ToolCallPart,
ToolExceedsTokenLimitError,
ToolReturn,
ToolReturnPart,
UnexpectedModelBehavior,
@@ -63,6 +64,7 @@
from pydantic_ai.output import StructuredDict, ToolOutput
from pydantic_ai.result import RunUsage
from pydantic_ai.run import AgentRunResultEvent
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied
from pydantic_ai.usage import RequestUsage

@@ -2448,6 +2450,45 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
)


def test_tool_exceeds_token_limit_error():
def return_incomplete_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar",')])
resp.finish_reason = 'length'
return resp

agent = Agent(FunctionModel(return_incomplete_tool), output_type=str)

with pytest.raises(
ToolExceedsTokenLimitError,
match=r'Model token limit \(10\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.',
):
agent.run_sync('Hello', model_settings=ModelSettings(max_tokens=10))

with pytest.raises(
ToolExceedsTokenLimitError,
match=r'Model token limit \(provider default\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.',
):
agent.run_sync('Hello')


def test_tool_exceeds_token_limit_but_complete_args():
def return_complete_tool_but_hit_limit(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar"}')])
resp.finish_reason = 'length'
return resp
return ModelResponse(parts=[TextPart('done')])

agent = Agent(FunctionModel(return_complete_tool_but_hit_limit), output_type=str)

@agent.tool_plain
def dummy_tool(foo: str) -> str:
return 'tool-ok'

result = agent.run_sync('Hello')
assert result.output == 'done'


def test_model_requests_blocked(env: TestEnv):
try:
env.set('GEMINI_API_KEY', 'foobar')