|
7 | 7 | from datetime import datetime |
8 | 8 | from typing import Any, Literal, cast, overload |
9 | 9 |
|
| 10 | +from pydantic import BaseModel, Json, ValidationError |
10 | 11 | from typing_extensions import assert_never |
11 | 12 |
|
12 | 13 | from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage |
|
48 | 49 | ) |
49 | 50 |
|
50 | 51 | try: |
51 | | - from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream |
| 52 | + from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream |
52 | 53 | from groq.types import chat |
53 | 54 | from groq.types.chat.chat_completion_content_part_image_param import ImageURL |
54 | 55 | except ImportError as _import_error: |
@@ -169,9 +170,22 @@ async def request( |
169 | 170 | model_request_parameters: ModelRequestParameters, |
170 | 171 | ) -> ModelResponse: |
171 | 172 | check_allow_model_requests() |
172 | | - response = await self._completions_create( |
173 | | - messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters |
174 | | - ) |
| 173 | + try: |
| 174 | + response = await self._completions_create( |
| 175 | + messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters |
| 176 | + ) |
| 177 | + except ModelHTTPError as e: |
| 178 | + if isinstance(e.body, dict): # pragma: no branch |
| 179 | + try: |
| 180 | + error = _GroqToolUseFailedError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType] |
| 181 | + tool_call_part = ToolCallPart( |
| 182 | + tool_name=error.error.failed_generation.name, |
| 183 | + args=error.error.failed_generation.arguments, |
| 184 | + ) |
| 185 | + return ModelResponse(parts=[tool_call_part]) |
| 186 | + except ValidationError: |
| 187 | + pass |
| 188 | + raise |
175 | 189 | model_response = self._process_response(response) |
176 | 190 | return model_response |
177 | 191 |
|
@@ -449,36 +463,50 @@ class GroqStreamedResponse(StreamedResponse): |
449 | 463 | _provider_name: str |
450 | 464 |
|
451 | 465 | async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: |
452 | | - async for chunk in self._response: |
453 | | - self._usage += _map_usage(chunk) |
454 | | - |
455 | | - try: |
456 | | - choice = chunk.choices[0] |
457 | | - except IndexError: |
458 | | - continue |
459 | | - |
460 | | - # Handle the text part of the response |
461 | | - content = choice.delta.content |
462 | | - if content is not None: |
463 | | - maybe_event = self._parts_manager.handle_text_delta( |
464 | | - vendor_part_id='content', |
465 | | - content=content, |
466 | | - thinking_tags=self._model_profile.thinking_tags, |
467 | | - ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, |
468 | | - ) |
469 | | - if maybe_event is not None: # pragma: no branch |
470 | | - yield maybe_event |
471 | | - |
472 | | - # Handle the tool calls |
473 | | - for dtc in choice.delta.tool_calls or []: |
474 | | - maybe_event = self._parts_manager.handle_tool_call_delta( |
475 | | - vendor_part_id=dtc.index, |
476 | | - tool_name=dtc.function and dtc.function.name, |
477 | | - args=dtc.function and dtc.function.arguments, |
478 | | - tool_call_id=dtc.id, |
479 | | - ) |
480 | | - if maybe_event is not None: |
481 | | - yield maybe_event |
| 466 | + try: |
| 467 | + async for chunk in self._response: |
| 468 | + self._usage += _map_usage(chunk) |
| 469 | + |
| 470 | + try: |
| 471 | + choice = chunk.choices[0] |
| 472 | + except IndexError: |
| 473 | + continue |
| 474 | + |
| 475 | + # Handle the text part of the response |
| 476 | + content = choice.delta.content |
| 477 | + if content is not None: |
| 478 | + maybe_event = self._parts_manager.handle_text_delta( |
| 479 | + vendor_part_id='content', |
| 480 | + content=content, |
| 481 | + thinking_tags=self._model_profile.thinking_tags, |
| 482 | + ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, |
| 483 | + ) |
| 484 | + if maybe_event is not None: # pragma: no branch |
| 485 | + yield maybe_event |
| 486 | + |
| 487 | + # Handle the tool calls |
| 488 | + for dtc in choice.delta.tool_calls or []: |
| 489 | + maybe_event = self._parts_manager.handle_tool_call_delta( |
| 490 | + vendor_part_id=dtc.index, |
| 491 | + tool_name=dtc.function and dtc.function.name, |
| 492 | + args=dtc.function and dtc.function.arguments, |
| 493 | + tool_call_id=dtc.id, |
| 494 | + ) |
| 495 | + if maybe_event is not None: |
| 496 | + yield maybe_event |
| 497 | + except APIError as e: |
| 498 | + if isinstance(e.body, dict): # pragma: no branch |
| 499 | + try: |
| 500 | + error = _GroqToolUseFailedInnerError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType] |
| 501 | + yield self._parts_manager.handle_tool_call_part( |
| 502 | + vendor_part_id='tool_use_failed', |
| 503 | + tool_name=error.failed_generation.name, |
| 504 | + args=error.failed_generation.arguments, |
| 505 | + ) |
| 506 | + return |
| 507 | + except ValidationError as e: |
| 508 | + pass |
| 509 | + raise |
482 | 510 |
|
483 | 511 | @property |
484 | 512 | def model_name(self) -> GroqModelName: |
@@ -510,3 +538,28 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us |
510 | 538 | input_tokens=response_usage.prompt_tokens, |
511 | 539 | output_tokens=response_usage.completion_tokens, |
512 | 540 | ) |
| 541 | + |
| 542 | + |
| 543 | +class _GroqToolUseFailedGeneration(BaseModel): |
| 544 | + name: str |
| 545 | + arguments: dict[str, Any] |
| 546 | + |
| 547 | + |
| 548 | +class _GroqToolUseFailedInnerError(BaseModel): |
| 549 | + message: str |
| 550 | + type: Literal['invalid_request_error'] |
| 551 | + code: Literal['tool_use_failed'] |
| 552 | + failed_generation: Json[_GroqToolUseFailedGeneration] |
| 553 | + |
| 554 | + |
| 555 | +class _GroqToolUseFailedError(BaseModel): |
| 556 | + # { |
| 557 | + # 'error': { |
| 558 | + # 'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]", |
| 559 | + # 'type': 'invalid_request_error', |
| 560 | + # 'code': 'tool_use_failed', |
| 561 | + # 'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}', |
| 562 | + # } |
| 563 | + # } |
| 564 | + |
| 565 | + error: _GroqToolUseFailedInnerError |
0 commit comments