Skip to content

Commit 2c5cb9e

Browse files
committed
Handle Groq tool_use_failed errors by telling model to retry as usual
1 parent 3c2624e commit 2c5cb9e

File tree

5 files changed

+992
-34
lines changed

5 files changed

+992
-34
lines changed

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datetime import datetime
88
from typing import Any, Literal, cast, overload
99

10+
from pydantic import BaseModel, Json, ValidationError
1011
from typing_extensions import assert_never
1112

1213
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -48,7 +49,7 @@
4849
)
4950

5051
try:
51-
from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
52+
from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
5253
from groq.types import chat
5354
from groq.types.chat.chat_completion_content_part_image_param import ImageURL
5455
except ImportError as _import_error:
@@ -169,9 +170,22 @@ async def request(
169170
model_request_parameters: ModelRequestParameters,
170171
) -> ModelResponse:
171172
check_allow_model_requests()
172-
response = await self._completions_create(
173-
messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
174-
)
173+
try:
174+
response = await self._completions_create(
175+
messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
176+
)
177+
except ModelHTTPError as e:
178+
if isinstance(e.body, dict): # pragma: no branch
179+
try:
180+
error = _GroqToolUseFailedError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType]
181+
tool_call_part = ToolCallPart(
182+
tool_name=error.error.failed_generation.name,
183+
args=error.error.failed_generation.arguments,
184+
)
185+
return ModelResponse(parts=[tool_call_part])
186+
except ValidationError:
187+
pass
188+
raise
175189
model_response = self._process_response(response)
176190
return model_response
177191

@@ -449,36 +463,50 @@ class GroqStreamedResponse(StreamedResponse):
449463
_provider_name: str
450464

451465
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
    """Yield stream events built from the chunks of the Groq response stream.

    For every chunk, its usage is added to the running total. Chunks with no
    choices are skipped. The first choice's text delta and tool-call deltas are
    passed to the parts manager, and any event it returns is yielded.

    If the stream raises an `APIError` whose body validates as a
    `tool_use_failed` error, the failed generation is emitted as a tool call
    part and the iterator ends normally instead of raising (presumably so the
    usual tool-call handling can ask the model to retry).

    Raises:
        APIError: re-raised unchanged when the error body is not a dict or does
            not validate as a `tool_use_failed` error.
    """
    try:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                # Chunk has no choices; nothing to turn into an event.
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                maybe_event = self._parts_manager.handle_text_delta(
                    vendor_part_id='content',
                    content=content,
                    thinking_tags=self._model_profile.thinking_tags,
                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                )
                if maybe_event is not None:  # pragma: no branch
                    yield maybe_event

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event
    except APIError as e:
        if isinstance(e.body, dict):  # pragma: no branch
            # Keep the `try` limited to validation so that only a body of the
            # wrong shape is ignored, and do not rebind `e` (an inner
            # `except ... as e` would shadow and then unbind the APIError name).
            try:
                error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
            except ValidationError:
                pass
            else:
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id='tool_use_failed',
                    tool_name=error.failed_generation.name,
                    args=error.failed_generation.arguments,
                )
                return
        # Not a recognised `tool_use_failed` error: propagate the original APIError.
        raise
482510

483511
@property
484512
def model_name(self) -> GroqModelName:
@@ -510,3 +538,28 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
510538
input_tokens=response_usage.prompt_tokens,
511539
output_tokens=response_usage.completion_tokens,
512540
)
541+
542+
543+
class _GroqToolUseFailedGeneration(BaseModel):
    """The tool call carried in a `tool_use_failed` error's `failed_generation` field."""

    # Name of the tool in the failed generation.
    name: str
    # Arguments of the failed generation, as a JSON object.
    arguments: dict[str, Any]
546+
547+
548+
class _GroqToolUseFailedInnerError(BaseModel):
    """The inner `error` object of a Groq `tool_use_failed` response.

    Validation only succeeds when `type` and `code` carry the exact literal
    values below, so other error payloads fail validation.
    """

    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    # Delivered as a JSON-encoded string; `Json[...]` decodes it and validates the result.
    failed_generation: Json[_GroqToolUseFailedGeneration]
553+
554+
555+
class _GroqToolUseFailedError(BaseModel):
    """Full error body of a Groq `tool_use_failed` response, wrapping the inner error."""

    # Example body:
    # {
    #     'error': {
    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
    #         'type': 'invalid_request_error',
    #         'code': 'tool_use_failed',
    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
    #     }
    # }

    error: _GroqToolUseFailedInnerError
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '92'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.groq.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: hello
20+
role: user
21+
model: non-existent
22+
n: 1
23+
stream: false
24+
uri: https://api.groq.com/openai/v1/chat/completions
25+
response:
26+
headers:
27+
alt-svc:
28+
- h3=":443"; ma=86400
29+
cache-control:
30+
- private, max-age=0, no-store, no-cache, must-revalidate
31+
connection:
32+
- keep-alive
33+
content-length:
34+
- '153'
35+
content-type:
36+
- application/json
37+
transfer-encoding:
38+
- chunked
39+
vary:
40+
- Origin
41+
parsed_body:
42+
error:
43+
code: model_not_found
44+
message: The model `non-existent` does not exist or you do not have access to it.
45+
type: invalid_request_error
46+
status:
47+
code: 404
48+
message: Not Found
49+
version: 1

0 commit comments

Comments
 (0)