121 changes: 87 additions & 34 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -7,6 +7,7 @@
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never
 
 from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
@@ -50,7 +51,7 @@
 )
 
 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -171,9 +172,22 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-        response = await self._completions_create(
-            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
-        )
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response
 
@@ -477,36 +491,50 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        async for chunk in self._response:
-            self._usage += _map_usage(chunk)
-
-            try:
-                choice = chunk.choices[0]
-            except IndexError:
-                continue
-
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    yield maybe_event
-
-            # Handle the tool calls
-            for dtc in choice.delta.tool_calls or []:
-                maybe_event = self._parts_manager.handle_tool_call_delta(
-                    vendor_part_id=dtc.index,
-                    tool_name=dtc.function and dtc.function.name,
-                    args=dtc.function and dtc.function.arguments,
-                    tool_call_id=dtc.id,
-                )
-                if maybe_event is not None:
-                    yield maybe_event
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover
 
     @property
     def model_name(self) -> GroqModelName:
@@ -538,3 +566,28 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
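A side note on the models above: Groq returns the failed tool call as a JSON-encoded string in the `failed_generation` field, and pydantic's `Json[...]` wrapper decodes and validates that string in the same `model_validate` pass. Below is a minimal standalone sketch of that behaviour; the three models are copied verbatim from the diff, while the sample payload is adapted from the comment in `_GroqToolUseFailedError` and the script itself is illustrative only.

```python
from typing import Any, Literal

from pydantic import BaseModel, Json


class _GroqToolUseFailedGeneration(BaseModel):
    name: str
    arguments: dict[str, Any]


class _GroqToolUseFailedInnerError(BaseModel):
    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    # `Json[...]` first parses the string as JSON, then validates the result
    # against _GroqToolUseFailedGeneration.
    failed_generation: Json[_GroqToolUseFailedGeneration]


class _GroqToolUseFailedError(BaseModel):
    error: _GroqToolUseFailedInnerError


# Sample error body, adapted from the comment in the diff above.
body = {
    'error': {
        'message': "Tool call validation failed: parameters for tool get_something_by_name did not match schema",
        'type': 'invalid_request_error',
        'code': 'tool_use_failed',
        'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
    }
}

parsed = _GroqToolUseFailedError.model_validate(body)
print(parsed.error.failed_generation.name)       # get_something_by_name
print(parsed.error.failed_generation.arguments)  # {'foo': 'bar'}
```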
98 changes: 98 additions & 0 deletions tests/models/cassettes/test_groq/test_tool_regular_error.yaml
@@ -0,0 +1,98 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '92'
+      content-type:
+      - application/json
+      host:
+      - api.groq.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: hello
+        role: user
+      model: non-existent
+      n: 1
+      stream: false
+    uri: https://api.groq.com/openai/v1/chat/completions
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=86400
+      cache-control:
+      - private, max-age=0, no-store, no-cache, must-revalidate
+      connection:
+      - keep-alive
+      content-length:
+      - '153'
+      content-type:
+      - application/json
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+    parsed_body:
+      error:
+        code: model_not_found
+        message: The model `non-existent` does not exist or you do not have access to it.
+        type: invalid_request_error
+    status:
+      code: 404
+      message: Not Found
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '92'
+      content-type:
+      - application/json
+      cookie:
+      - __cf_bm=iezhhe87vg5KFYcM9IscO.LHkV8Rv0iM504X.mfCjyk-1756847105-1.0.1.1-otC4FjCOjx._h6QtwT8RmzJksvvp7AaYmv5_.yVAgRa9R4Aon_.qtG1HP7KVwunhgrIiWaZPQhPwh81YLf3h9rC_loDrBn.TQgDMlZ19B0w
+      host:
+      - api.groq.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: hello
+        role: user
+      model: non-existent
+      n: 1
+      stream: false
+    uri: https://api.groq.com/openai/v1/chat/completions
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=86400
+      cache-control:
+      - private, max-age=0, no-store, no-cache, must-revalidate
+      connection:
+      - keep-alive
+      content-length:
+      - '153'
+      content-type:
+      - application/json
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+    parsed_body:
+      error:
+        code: model_not_found
+        message: The model `non-existent` does not exist or you do not have access to it.
+        type: invalid_request_error
+    status:
+      code: 404
+      message: Not Found
+version: 1
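For context, this cassette backs the new `test_tool_regular_error` test, which checks that an ordinary Groq error (here a 404 for a non-existent model) is still raised instead of being converted into a tool call. A rough sketch of what such a test could look like follows; the fixture names (`allow_model_requests`, `groq_api_key`) and the single `agent.run` call are assumptions not shown in this diff, and the cassette records two identical requests, so the real test presumably makes more than one call.

```python
import pytest

from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.providers.groq import GroqProvider


async def test_tool_regular_error(allow_model_requests: None, groq_api_key: str):
    # Hypothetical sketch: a request against a non-existent model is a plain
    # HTTP 404, not a `tool_use_failed` error, so it should surface as
    # ModelHTTPError rather than being turned into a ToolCallPart.
    model = GroqModel('non-existent', provider=GroqProvider(api_key=groq_api_key))
    agent = Agent(model)
    with pytest.raises(ModelHTTPError):
        await agent.run('hello')
```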