|
7 | 7 | from datetime import datetime
|
8 | 8 | from typing import Any, Literal, cast, overload
|
9 | 9 |
|
| 10 | +from pydantic import BaseModel, Json, ValidationError |
10 | 11 | from typing_extensions import assert_never
|
11 | 12 |
|
12 | 13 | from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
|
|
50 | 51 | )
|
51 | 52 |
|
52 | 53 | try:
|
53 |
| - from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream |
| 54 | + from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream |
54 | 55 | from groq.types import chat
|
55 | 56 | from groq.types.chat.chat_completion_content_part_image_param import ImageURL
|
56 | 57 | except ImportError as _import_error:
|
@@ -171,9 +172,24 @@ async def request(
|
171 | 172 | model_request_parameters: ModelRequestParameters,
|
172 | 173 | ) -> ModelResponse:
|
173 | 174 | check_allow_model_requests()
|
174 |
| - response = await self._completions_create( |
175 |
| - messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters |
176 |
| - ) |
| 175 | + try: |
| 176 | + response = await self._completions_create( |
| 177 | + messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters |
| 178 | + ) |
| 179 | + except ModelHTTPError as e: |
| 180 | + if isinstance(e.body, dict): # pragma: no branch |
| 181 | + # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema, |
| 182 | + # but we'd rather handle it ourselves so we can tell the model to retry the tool call. |
| 183 | + try: |
| 184 | + error = _GroqToolUseFailedError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType] |
| 185 | + tool_call_part = ToolCallPart( |
| 186 | + tool_name=error.error.failed_generation.name, |
| 187 | + args=error.error.failed_generation.arguments, |
| 188 | + ) |
| 189 | + return ModelResponse(parts=[tool_call_part]) |
| 190 | + except ValidationError: |
| 191 | + pass |
| 192 | + raise |
177 | 193 | model_response = self._process_response(response)
|
178 | 194 | return model_response
|
179 | 195 |
|
@@ -477,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
|
477 | 493 | _provider_name: str
|
478 | 494 |
|
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
    """Translate the streamed Groq response into model response stream events.

    For every chunk the usage is accumulated into `self._usage`; chunks without
    any choice are skipped. Text deltas and tool call deltas of the first choice
    are forwarded to the parts manager, and every event it produces is yielded.

    If the stream fails with an `APIError` whose body validates as a
    `tool_use_failed` error, the failed generation is emitted as a regular tool
    call part instead of propagating the exception, so the tool call can be
    validated (and retried) by the caller.

    Raises:
        APIError: when the error body is not a dict or is not a
            `tool_use_failed` payload.
    """
    try:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                # Chunk carries no choices (nothing to emit for it).
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                maybe_event = self._parts_manager.handle_text_delta(
                    vendor_part_id='content',
                    content=content,
                    thinking_tags=self._model_profile.thinking_tags,
                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                )
                if maybe_event is not None:  # pragma: no branch
                    yield maybe_event

            # Handle the tool calls
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event
    except APIError as e:
        # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
        # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
        error: _GroqToolUseFailedInnerError | None = None
        if isinstance(e.body, dict):  # pragma: no branch
            # Keep the `try` limited to the validation itself so that a `ValidationError`
            # coming from anywhere else is never mistaken for "not a tool_use_failed payload".
            try:
                error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
            except ValidationError:  # pragma: no cover
                # Some other kind of API error: fall through and re-raise it below.
                pass
        if error is None:  # pragma: no cover
            # Bare `raise` re-raises the `APIError` currently being handled.
            raise
        yield self._parts_manager.handle_tool_call_part(
            vendor_part_id='tool_use_failed',
            tool_name=error.failed_generation.name,
            args=error.failed_generation.arguments,
        )
510 | 542 |
|
511 | 543 | @property
|
512 | 544 | def model_name(self) -> GroqModelName:
|
@@ -538,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
|
538 | 570 | input_tokens=response_usage.prompt_tokens,
|
539 | 571 | output_tokens=response_usage.completion_tokens,
|
540 | 572 | )
|
| 573 | + |
| 574 | + |
class _GroqToolUseFailedGeneration(BaseModel):
    """The tool call Groq generated but rejected, as carried in `failed_generation`."""

    # Name of the tool the model tried to call.
    name: str
    # Arguments the model generated for that tool call.
    arguments: dict[str, Any]
| 578 | + |
| 579 | + |
class _GroqToolUseFailedInnerError(BaseModel):
    """Error object of a Groq `tool_use_failed` response.

    Validation only succeeds when `type` is `'invalid_request_error'` and `code` is
    `'tool_use_failed'`, so any other error payload raises a `ValidationError`.
    """

    # Error description as reported in the payload.
    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    # Arrives as a JSON-encoded string; `Json[...]` decodes it and validates the result.
    failed_generation: Json[_GroqToolUseFailedGeneration]
| 585 | + |
| 586 | + |
class _GroqToolUseFailedError(BaseModel):
    """Envelope of a Groq `tool_use_failed` error body: the error object nested under `error`."""

    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
    #
    # Example payload from `exception.body`:
    # {
    #     'error': {
    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
    #         'type': 'invalid_request_error',
    #         'code': 'tool_use_failed',
    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n  "foo": "bar"\n}}',
    #     }
    # }

    error: _GroqToolUseFailedInnerError
0 commit comments