Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from typing_extensions import assert_never

from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._run_context import RunContext
from .._thinking_part import split_content_into_text_and_thinking
Expand Down Expand Up @@ -228,6 +230,18 @@ async def _completions_create(

groq_messages = self._map_messages(messages)

response_format: chat.completion_create_params.ResponseFormat | None = None
if model_request_parameters.output_mode == 'native':
output_object = model_request_parameters.output_object
assert output_object is not None
response_format = self._map_json_schema(output_object)
elif (
model_request_parameters.output_mode == 'prompted'
and not tools
and self.profile.supports_json_object_output
): # pragma: no branch
response_format = {'type': 'json_object'}

try:
extra_headers = model_settings.get('extra_headers', {})
extra_headers.setdefault('User-Agent', get_user_agent())
Expand All @@ -240,6 +254,7 @@ async def _completions_create(
tool_choice=tool_choice or NOT_GIVEN,
stop=model_settings.get('stop_sequences', NOT_GIVEN),
stream=stream,
response_format=response_format or NOT_GIVEN,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
Expand Down Expand Up @@ -385,6 +400,19 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
},
}

def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
    """Build a Groq `json_schema` response-format param from an output object definition.

    Falls back to `DEFAULT_OUTPUT_TOOL_NAME` when the definition has no name, and
    only includes a `description` key when one is present on the definition.
    """
    schema_payload = {
        'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
        'schema': o.json_schema,
        'strict': o.strict,
    }
    if o.description:  # pragma: no branch
        schema_payload['description'] = o.description
    response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
        'type': 'json_schema',
        'json_schema': schema_payload,
    }
    return response_format_param

@classmethod
def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
for part in message.parts:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def _map_tool_call(t: ToolCallPart) -> ChatCompletionMessageFunctionToolCallPara
def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage]
'type': 'json_schema',
'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True},
'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
}
if o.description:
response_format_param['json_schema']['description'] = o.description
Expand Down
23 changes: 21 additions & 2 deletions pydantic_ai_slim/pydantic_ai/providers/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import openai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

Expand All @@ -26,6 +27,23 @@
) from _import_error


def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
    """Get the model profile for a MoonshotAI model used with the Groq provider.

    Layers `supports_json_object_output` and `supports_json_schema_output` on top of
    the base MoonshotAI profile, since Groq's serving of these models accepts both
    structured response formats.
    """
    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
        moonshotai_model_profile(model_name)
    )


def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
    """Get the model profile for a Meta model used with the Groq provider.

    Only the listed Llama 4 variants get the JSON object/schema output capabilities
    enabled on top of the base Meta profile; every other model uses the base profile
    unchanged.
    """
    # Models for which Groq supports structured (JSON object/schema) output.
    structured_output_models = ('llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct')
    base_profile = meta_model_profile(model_name)
    if model_name not in structured_output_models:
        return base_profile
    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(base_profile)


class GroqProvider(Provider[AsyncGroq]):
"""Provider for Groq API."""

Expand All @@ -44,13 +62,14 @@ def client(self) -> AsyncGroq:
def model_profile(self, model_name: str) -> ModelProfile | None:
prefix_to_profile = {
'llama': meta_model_profile,
'meta-llama/': meta_model_profile,
'meta-llama/': meta_groq_model_profile,
'gemma': google_model_profile,
'qwen': qwen_model_profile,
'deepseek': deepseek_model_profile,
'mistral': mistral_model_profile,
'moonshotai/': moonshotai_model_profile,
'moonshotai/': groq_moonshotai_model_profile,
'compound-': groq_model_profile,
'openai/': openai_model_profile,
}

for prefix, profile_func in prefix_to_profile.items():
Expand Down
89 changes: 89 additions & 0 deletions tests/models/cassettes/test_groq/test_groq_native_output.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
interactions:
- request:
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '416'
content-type:
- application/json
host:
- api.groq.com
method: POST
parsed_body:
messages:
- content: What is the largest city in Mexico?
role: user
model: openai/gpt-oss-120b
n: 1
response_format:
json_schema:
description: A city and its country.
name: CityLocation
schema:
additionalProperties: false
properties:
city:
type: string
country:
type: string
required:
- city
- country
type: object
strict: true
type: json_schema
stream: false
uri: https://api.groq.com/openai/v1/chat/completions
response:
headers:
alt-svc:
- h3=":443"; ma=86400
cache-control:
- private, max-age=0, no-store, no-cache, must-revalidate
connection:
- keep-alive
content-length:
- '947'
content-type:
- application/json
transfer-encoding:
- chunked
vary:
- Origin
parsed_body:
choices:
- finish_reason: stop
index: 0
logprobs: null
message:
content: '{"city":"Mexico City","country":"Mexico"}'
reasoning: 'The user asks: "What is the largest city in Mexico?" The system expects a JSON object conforming to
CityLocation schema: properties city (string) and country (string), required both. Provide largest city in Mexico:
Mexico City. So output JSON: {"city":"Mexico City","country":"Mexico"} in compact format, no extra text.'
role: assistant
created: 1756843265
id: chatcmpl-92437948-262c-49fe-87d1-774e54201105
model: openai/gpt-oss-120b
object: chat.completion
service_tier: on_demand
system_fingerprint: fp_213abb2467
usage:
completion_time: 0.186978247
completion_tokens: 94
prompt_time: 0.008149307
prompt_tokens: 178
queue_time: 1.095749809
total_time: 0.195127554
total_tokens: 272
usage_breakdown: null
x_groq:
id: req_01k4609j8zemga5d9erfnwbfma
status:
code: 200
message: OK
version: 1
81 changes: 81 additions & 0 deletions tests/models/cassettes/test_groq/test_groq_prompted_output.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
interactions:
- request:
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '516'
content-type:
- application/json
host:
- api.groq.com
method: POST
parsed_body:
messages:
- content: |-
Always respond with a JSON object that's compatible with this schema:

{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}

Don't include any text or Markdown fencing before or after.
role: system
- content: What is the largest city in Mexico?
role: user
model: openai/gpt-oss-120b
n: 1
response_format:
type: json_object
stream: false
uri: https://api.groq.com/openai/v1/chat/completions
response:
headers:
alt-svc:
- h3=":443"; ma=86400
cache-control:
- private, max-age=0, no-store, no-cache, must-revalidate
connection:
- keep-alive
content-length:
- '926'
content-type:
- application/json
transfer-encoding:
- chunked
vary:
- Origin
parsed_body:
choices:
- finish_reason: stop
index: 0
logprobs: null
message:
content: '{"city":"Mexico City","country":"Mexico"}'
reasoning: 'We need to respond with JSON object with properties city and country. The question: "What is the largest
city in Mexico?" The answer: City is Mexico City, country is Mexico. Must output compact JSON without any extra
text or markdown. So {"city":"Mexico City","country":"Mexico"} Ensure valid JSON.'
role: assistant
created: 1756843266
id: chatcmpl-d7085def-1e9f-45d7-b90b-65633ef23489
model: openai/gpt-oss-120b
object: chat.completion
service_tier: on_demand
system_fingerprint: fp_ed9190d8b7
usage:
completion_time: 0.173182068
completion_tokens: 87
prompt_time: 0.006958709
prompt_tokens: 177
queue_time: 0.212268627
total_time: 0.180140777
total_tokens: 264
usage_breakdown: null
x_groq:
id: req_01k4609m4zf9qsbf0wwgvw8398
status:
code: 200
message: OK
version: 1
89 changes: 89 additions & 0 deletions tests/models/test_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
from dirty_equals import IsListOrTuple
from inline_snapshot import snapshot
from pydantic import BaseModel
from typing_extensions import TypedDict

from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
Expand All @@ -37,6 +38,7 @@
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.output import NativeOutput, PromptedOutput
from pydantic_ai.usage import RequestUsage, RunUsage

from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import
Expand Down Expand Up @@ -1004,3 +1006,90 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap
length=996,
)
)


async def test_groq_native_output(allow_model_requests: None, groq_api_key: str):
    """NativeOutput on Groq: the model's JSON reply is parsed into the Pydantic type.

    Replays the `test_groq_native_output` cassette; the response includes a
    ThinkingPart (model reasoning) followed by the JSON TextPart.
    """
    m = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key=groq_api_key))

    class CityLocation(BaseModel):
        # NOTE: this class docstring is sent to the API as the schema description.
        """A city and its country."""

        city: str
        country: str

    agent = Agent(m, output_type=NativeOutput(CityLocation))

    result = await agent.run('What is the largest city in Mexico?')
    assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))

    # Full message history: user request, then a response whose thinking and text
    # parts match the recorded cassette exactly.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(
                        content='What is the largest city in Mexico?',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ThinkingPart(
                        content='The user asks: "What is the largest city in Mexico?" The system expects a JSON object conforming to CityLocation schema: properties city (string) and country (string), required both. Provide largest city in Mexico: Mexico City. So output JSON: {"city":"Mexico City","country":"Mexico"} in compact format, no extra text.'
                    ),
                    TextPart(content='{"city":"Mexico City","country":"Mexico"}'),
                ],
                usage=RequestUsage(input_tokens=178, output_tokens=94),
                model_name='openai/gpt-oss-120b',
                timestamp=IsDatetime(),
                provider_name='groq',
                provider_response_id=IsStr(),
            ),
        ]
    )


async def test_groq_prompted_output(allow_model_requests: None, groq_api_key: str):
    """PromptedOutput on Groq: the schema is injected as instructions, not response_format.

    Replays the `test_groq_prompted_output` cassette; the request uses
    `response_format: json_object` plus a system prompt describing the schema.
    """
    m = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key=groq_api_key))

    class CityLocation(BaseModel):
        city: str
        country: str

    agent = Agent(m, output_type=PromptedOutput(CityLocation))

    result = await agent.run('What is the largest city in Mexico?')
    assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))

    # The instructions string below is the exact schema prompt generated for
    # PromptedOutput and must match the recorded request byte-for-byte.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(
                        content='What is the largest city in Mexico?',
                        timestamp=IsDatetime(),
                    )
                ],
                instructions="""\
Always respond with a JSON object that's compatible with this schema:

{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}

Don't include any text or Markdown fencing before or after.\
""",
            ),
            ModelResponse(
                parts=[
                    ThinkingPart(
                        content='We need to respond with JSON object with properties city and country. The question: "What is the largest city in Mexico?" The answer: City is Mexico City, country is Mexico. Must output compact JSON without any extra text or markdown. So {"city":"Mexico City","country":"Mexico"} Ensure valid JSON.'
                    ),
                    TextPart(content='{"city":"Mexico City","country":"Mexico"}'),
                ],
                usage=RequestUsage(input_tokens=177, output_tokens=87),
                model_name='openai/gpt-oss-120b',
                timestamp=IsDatetime(),
                provider_name='groq',
                provider_response_id=IsStr(),
            ),
        ]
    )
Loading