Skip to content

Commit 175ef81

Browse files
authored
Handle Groq tool_use_failed errors by getting model to retry (#2774)
1 parent df14088 commit 175ef81

File tree

5 files changed

+1254
-34
lines changed

5 files changed

+1254
-34
lines changed

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 94 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datetime import datetime
88
from typing import Any, Literal, cast, overload
99

10+
from pydantic import BaseModel, Json, ValidationError
1011
from typing_extensions import assert_never
1112

1213
from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
@@ -50,7 +51,7 @@
5051
)
5152

5253
try:
53-
from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
54+
from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
5455
from groq.types import chat
5556
from groq.types.chat.chat_completion_content_part_image_param import ImageURL
5657
except ImportError as _import_error:
@@ -171,9 +172,24 @@ async def request(
171172
model_request_parameters: ModelRequestParameters,
172173
) -> ModelResponse:
173174
check_allow_model_requests()
174-
response = await self._completions_create(
175-
messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
176-
)
175+
try:
176+
response = await self._completions_create(
177+
messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
178+
)
179+
except ModelHTTPError as e:
180+
if isinstance(e.body, dict): # pragma: no branch
181+
# The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
182+
# but we'd rather handle it ourselves so we can tell the model to retry the tool call.
183+
try:
184+
error = _GroqToolUseFailedError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType]
185+
tool_call_part = ToolCallPart(
186+
tool_name=error.error.failed_generation.name,
187+
args=error.error.failed_generation.arguments,
188+
)
189+
return ModelResponse(parts=[tool_call_part])
190+
except ValidationError:
191+
pass
192+
raise
177193
model_response = self._process_response(response)
178194
return model_response
179195

@@ -477,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
477493
_provider_name: str
478494

479495
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
    """Yield stream events for each chunk of the Groq streaming response.

    Usage from every chunk is accumulated into `self._usage`; text deltas and
    tool call deltas are routed through `self._parts_manager`, and any event it
    returns is yielded.

    If iterating the stream raises an `APIError` whose body validates as a
    `tool_use_failed` payload, the failed generation is yielded as a tool call
    part and the iterator ends normally. Any other `APIError` is re-raised unchanged.
    """
    try:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                # Chunk without choices: only its usage (recorded above) is relevant.
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                maybe_event = self._parts_manager.handle_text_delta(
                    vendor_part_id='content',
                    content=content,
                    thinking_tags=self._model_profile.thinking_tags,
                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                )
                if maybe_event is not None:  # pragma: no branch
                    yield maybe_event

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event
    except APIError as e:
        if isinstance(e.body, dict):  # pragma: no branch
            # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
            # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
            try:
                error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
            except ValidationError:  # pragma: no cover
                # Not a `tool_use_failed` payload: fall through and re-raise the original APIError.
                # (No `as` binding here: it would rebind and then unbind the outer `e`.)
                pass
            else:
                # Only the validation above is guarded, so errors raised while emitting
                # the part are not mistaken for a payload mismatch.
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id='tool_use_failed',
                    tool_name=error.failed_generation.name,
                    args=error.failed_generation.arguments,
                )
                return
        raise  # pragma: no cover
510542

511543
@property
512544
def model_name(self) -> GroqModelName:
@@ -538,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
538570
input_tokens=response_usage.prompt_tokens,
539571
output_tokens=response_usage.completion_tokens,
540572
)
573+
574+
575+
class _GroqToolUseFailedGeneration(BaseModel):
    """The tool call generated by the model that was reported in a `tool_use_failed` error."""

    # Name of the tool the model tried to call; passed on as the tool call's `tool_name`.
    name: str
    # Arguments the model generated for the call; passed on as the tool call's `args`.
    arguments: dict[str, Any]
578+
579+
580+
class _GroqToolUseFailedInnerError(BaseModel):
    """Error object of a Groq `tool_use_failed` failure.

    The `type` and `code` literals restrict validation to this specific error,
    so any other error body fails validation.
    """

    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    # Delivered as a JSON string; `Json[...]` parses it into the failed tool call.
    failed_generation: Json[_GroqToolUseFailedGeneration]
585+
586+
587+
class _GroqToolUseFailedError(BaseModel):
    """Full exception body of a Groq `tool_use_failed` error, with the details under the `error` key."""

    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
    # Example payload from `exception.body`:
    # {
    #     'error': {
    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
    #         'type': 'invalid_request_error',
    #         'code': 'tool_use_failed',
    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
    #     }
    # }

    error: _GroqToolUseFailedInnerError
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '92'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.groq.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: hello
20+
role: user
21+
model: non-existent
22+
n: 1
23+
stream: false
24+
uri: https://api.groq.com/openai/v1/chat/completions
25+
response:
26+
headers:
27+
alt-svc:
28+
- h3=":443"; ma=86400
29+
cache-control:
30+
- private, max-age=0, no-store, no-cache, must-revalidate
31+
connection:
32+
- keep-alive
33+
content-length:
34+
- '153'
35+
content-type:
36+
- application/json
37+
transfer-encoding:
38+
- chunked
39+
vary:
40+
- Origin
41+
parsed_body:
42+
error:
43+
code: model_not_found
44+
message: The model `non-existent` does not exist or you do not have access to it.
45+
type: invalid_request_error
46+
status:
47+
code: 404
48+
message: Not Found
49+
- request:
50+
headers:
51+
accept:
52+
- application/json
53+
accept-encoding:
54+
- gzip, deflate
55+
connection:
56+
- keep-alive
57+
content-length:
58+
- '92'
59+
content-type:
60+
- application/json
61+
cookie:
62+
- __cf_bm=iezhhe87vg5KFYcM9IscO.LHkV8Rv0iM504X.mfCjyk-1756847105-1.0.1.1-otC4FjCOjx._h6QtwT8RmzJksvvp7AaYmv5_.yVAgRa9R4Aon_.qtG1HP7KVwunhgrIiWaZPQhPwh81YLf3h9rC_loDrBn.TQgDMlZ19B0w
63+
host:
64+
- api.groq.com
65+
method: POST
66+
parsed_body:
67+
messages:
68+
- content: hello
69+
role: user
70+
model: non-existent
71+
n: 1
72+
stream: false
73+
uri: https://api.groq.com/openai/v1/chat/completions
74+
response:
75+
headers:
76+
alt-svc:
77+
- h3=":443"; ma=86400
78+
cache-control:
79+
- private, max-age=0, no-store, no-cache, must-revalidate
80+
connection:
81+
- keep-alive
82+
content-length:
83+
- '153'
84+
content-type:
85+
- application/json
86+
transfer-encoding:
87+
- chunked
88+
vary:
89+
- Origin
90+
parsed_body:
91+
error:
92+
code: model_not_found
93+
message: The model `non-existent` does not exist or you do not have access to it.
94+
type: invalid_request_error
95+
status:
96+
code: 404
97+
message: Not Found
98+
version: 1

0 commit comments

Comments
 (0)