diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py
index 8f6254f425..77b605b5ee 100644
--- a/pydantic_ai_slim/pydantic_ai/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/__init__.py
@@ -22,6 +22,7 @@
     ApprovalRequired,
     CallDeferred,
     FallbackExceptionGroup,
+    IncompleteToolCall,
     ModelHTTPError,
     ModelRetry,
     UnexpectedModelBehavior,
@@ -124,6 +125,7 @@
     'ModelRetry',
     'ModelHTTPError',
     'FallbackExceptionGroup',
+    'IncompleteToolCall',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 3867473d1f..f0733a8abe 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -92,9 +92,28 @@ class GraphAgentState:
     retries: int = 0
     run_step: int = 0
 
-    def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
+    def increment_retries(
+        self,
+        max_result_retries: int,
+        error: BaseException | None = None,
+        model_settings: ModelSettings | None = None,
+    ) -> None:
         self.retries += 1
         if self.retries > max_result_retries:
+            if (
+                self.message_history
+                and isinstance(model_response := self.message_history[-1], _messages.ModelResponse)
+                and model_response.finish_reason == 'length'
+                and model_response.parts
+                and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart)
+            ):
+                try:
+                    tool_call.args_as_dict()
+                except Exception:
+                    max_tokens = (model_settings or {}).get('max_tokens') if model_settings else None
+                    raise exceptions.IncompleteToolCall(
+                        f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
+                    )
             message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
             if error:
                 if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
@@ -568,7 +587,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
                 # resubmit the most recent request that resulted in an empty response,
                 # as the empty response and request will not create any items in the API payload,
                 # in the hope the model will return a non-empty response this time.
-                ctx.state.increment_retries(ctx.deps.max_result_retries)
+                ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
                 self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
                 return
 
@@ -630,7 +649,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
                     )
                     raise ToolRetryError(m)
             except ToolRetryError as e:
-                ctx.state.increment_retries(ctx.deps.max_result_retries, e)
+                ctx.state.increment_retries(
+                    ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+                )
                 self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
 
         self._events_iterator = _run_stream()
@@ -788,10 +809,14 @@ async def process_tool_calls(  # noqa: C901
             try:
                 result_data = await tool_manager.handle_call(call)
             except exceptions.UnexpectedModelBehavior as e:
-                ctx.state.increment_retries(ctx.deps.max_result_retries, e)
+                ctx.state.increment_retries(
+                    ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+                )
                 raise e  # pragma: lax no cover
             except ToolRetryError as e:
-                ctx.state.increment_retries(ctx.deps.max_result_retries, e)
+                ctx.state.increment_retries(
+                    ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+                )
                 yield _messages.FunctionToolCallEvent(call)
                 output_parts.append(e.tool_retry)
                 yield _messages.FunctionToolResultEvent(e.tool_retry)
@@ -820,7 +845,7 @@
 
     # Then, we handle unknown tool calls
     if tool_calls_by_kind['unknown']:
-        ctx.state.increment_retries(ctx.deps.max_result_retries)
+        ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
         calls_to_run.extend(tool_calls_by_kind['unknown'])
 
     calls_to_run_results: dict[str, DeferredToolResult] = {}
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index 58a7686e06..d3d2672083 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -23,6 +23,7 @@
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'ModelHTTPError',
+    'IncompleteToolCall',
     'FallbackExceptionGroup',
 )
 
@@ -168,3 +169,7 @@ class ToolRetryError(Exception):
     def __init__(self, tool_retry: RetryPromptPart):
         self.tool_retry = tool_retry
         super().__init__()
+
+
+class IncompleteToolCall(UnexpectedModelBehavior):
+    """Error raised when a model stops due to token limit while emitting a tool call."""
diff --git a/tests/test_agent.py b/tests/test_agent.py
index c8beb08312..a0b271c3f3 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -27,6 +27,7 @@
     DocumentUrl,
     FunctionToolset,
     ImageUrl,
+    IncompleteToolCall,
     ModelMessage,
     ModelMessagesTypeAdapter,
     ModelProfile,
@@ -63,6 +64,7 @@
 from pydantic_ai.output import StructuredDict, ToolOutput
 from pydantic_ai.result import RunUsage
 from pydantic_ai.run import AgentRunResultEvent
+from pydantic_ai.settings import ModelSettings
 from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied
 from pydantic_ai.usage import RequestUsage
 
@@ -2448,6 +2450,45 @@
 
 
+def test_tool_exceeds_token_limit_error():
+    def return_incomplete_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar",')])
+        resp.finish_reason = 'length'
+        return resp
+
+    agent = Agent(FunctionModel(return_incomplete_tool), output_type=str)
+
+    with pytest.raises(
+        IncompleteToolCall,
+        match=r'Model token limit \(10\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.',
+    ):
+        agent.run_sync('Hello', model_settings=ModelSettings(max_tokens=10))
+
+    with pytest.raises(
+        IncompleteToolCall,
+        match=r'Model token limit \(provider default\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.',
+    ):
+        agent.run_sync('Hello')
+
+
+def test_tool_exceeds_token_limit_but_complete_args():
+    def return_complete_tool_but_hit_limit(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if len(messages) == 1:
+            resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar"}')])
+            resp.finish_reason = 'length'
+            return resp
+        return ModelResponse(parts=[TextPart('done')])
+
+    agent = Agent(FunctionModel(return_complete_tool_but_hit_limit), output_type=str)
+
+    @agent.tool_plain
+    def dummy_tool(foo: str) -> str:
+        return 'tool-ok'
+
+    result = agent.run_sync('Hello')
+    assert result.output == 'done'
+
+
 def test_model_requests_blocked(env: TestEnv):
     try:
         env.set('GEMINI_API_KEY', 'foobar')
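
For context on how the new exception surfaces to callers, here is a minimal usage sketch (not part of the diff; the model name and token budgets are illustrative assumptions):

# Minimal usage sketch, assuming an OpenAI model is configured.
# The model name and max_tokens values are illustrative, not from this PR.
from pydantic_ai import Agent, IncompleteToolCall
from pydantic_ai.settings import ModelSettings

agent = Agent('openai:gpt-4o')  # hypothetical model choice

try:
    # A tight token budget can cut the model off mid tool call, which now
    # raises IncompleteToolCall instead of a generic retry-exhaustion error.
    result = agent.run_sync('Summarize the report', model_settings=ModelSettings(max_tokens=16))
except IncompleteToolCall:
    # Retry with more headroom so the tool call arguments can be emitted in full.
    result = agent.run_sync('Summarize the report', model_settings=ModelSettings(max_tokens=1024))

print(result.output)

Because IncompleteToolCall subclasses UnexpectedModelBehavior, existing `except UnexpectedModelBehavior` handlers continue to work unchanged.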