Commit 55ad5bd

Support NativeOutput with Groq (#2772)
1 parent 3c2624e commit 55ad5bd
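
In short: GroqModel now maps NativeOutput to Groq's json_schema response format, and PromptedOutput to json_object mode when the model's profile supports it. A minimal usage sketch, adapted from the tests added in this commit (the API key below is a placeholder):

from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.output import NativeOutput
from pydantic_ai.providers.groq import GroqProvider


class CityLocation(BaseModel):
    """A city and its country."""

    city: str
    country: str


# Groq enforces the CityLocation JSON schema server-side via response_format.
model = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key='gsk-...'))
agent = Agent(model, output_type=NativeOutput(CityLocation))
result = agent.run_sync('What is the largest city in Mexico?')
print(result.output)  # city='Mexico City' country='Mexico'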

7 files changed: +341 -3 lines changed

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 28 additions & 0 deletions
@@ -9,6 +9,8 @@
 
 from typing_extensions import assert_never
 
+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -228,6 +230,18 @@ async def _completions_create(
 
         groq_messages = self._map_messages(messages)
 
+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +254,7 @@
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -385,6 +400,19 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
             },
         }
 
+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
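
For reference, here is the response_format payload that _map_json_schema produces for the CityLocation model used in the new tests; this is an illustrative sketch that mirrors the recorded cassette below, not a separate API:

# Illustrative value of response_format for NativeOutput(CityLocation):
response_format = {
    'type': 'json_schema',
    'json_schema': {
        'name': 'CityLocation',
        'description': 'A city and its country.',
        'schema': {
            'type': 'object',
            'additionalProperties': False,
            'properties': {'city': {'type': 'string'}, 'country': {'type': 'string'}},
            'required': ['city', 'country'],
        },
        'strict': True,  # taken from o.strict on the output definition
    },
}

When output_mode is 'prompted' and the profile supports it, the model instead sends {'type': 'json_object'} and relies on schema instructions in the prompt, as the second cassette shows.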

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion
@@ -606,7 +606,7 @@ def _map_tool_call(t: ToolCallPart) -> ChatCompletionMessageFunctionToolCallParam:
     def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
         response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = {  # pyright: ignore[reportPrivateImportUsage]
             'type': 'json_schema',
-            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True},
+            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
         }
         if o.description:
             response_format_param['json_schema']['description'] = o.description
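
Note that the OpenAI-side change only drops the hardcoded 'strict': True from the json_schema payload. The replacement behavior isn't visible in this hunk; presumably strictness is now derived from the output definition elsewhere in the mapper, in line with the new Groq mapper passing o.strict through.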

pydantic_ai_slim/pydantic_ai/providers/groq.py

Lines changed: 21 additions & 2 deletions
@@ -14,6 +14,7 @@
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
@@ -26,6 +27,23 @@
 ) from _import_error
 
 
+def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a MoonshotAI model used with the Groq provider."""
+    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+        moonshotai_model_profile(model_name)
+    )
+
+
+def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a Meta model used with the Groq provider."""
+    if model_name in {'llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct'}:
+        return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+            meta_model_profile(model_name)
+        )
+    else:
+        return meta_model_profile(model_name)
+
+
 class GroqProvider(Provider[AsyncGroq]):
     """Provider for Groq API."""
 
@@ -44,13 +62,14 @@ def client(self) -> AsyncGroq:
     def model_profile(self, model_name: str) -> ModelProfile | None:
         prefix_to_profile = {
             'llama': meta_model_profile,
-            'meta-llama/': meta_model_profile,
+            'meta-llama/': meta_groq_model_profile,
             'gemma': google_model_profile,
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
-            'moonshotai/': moonshotai_model_profile,
+            'moonshotai/': groq_moonshotai_model_profile,
             'compound-': groq_model_profile,
+            'openai/': openai_model_profile,
         }
 
         for prefix, profile_func in prefix_to_profile.items():
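
A small sketch of what the updated prefix table means in practice, assuming a configured provider and that prefix matching strips the vendor prefix before calling the profile function, as the bare model names in meta_groq_model_profile suggest (placeholder API key; the model name is one of the Llama 4 IDs listed in the diff above):

from pydantic_ai.providers.groq import GroqProvider

provider = GroqProvider(api_key='gsk-...')

# 'meta-llama/'-prefixed Llama 4 models now resolve through meta_groq_model_profile,
# which flags support for both JSON object and JSON schema output.
profile = provider.model_profile('meta-llama/llama-4-scout-17b-16e-instruct')
assert profile is not None
assert profile.supports_json_object_output and profile.supports_json_schema_output

The new 'openai/' entry is presumably what gives openai/gpt-oss-120b, the model exercised in the tests below, an OpenAI-style profile on Groq.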
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '416'
+      content-type:
+      - application/json
+      host:
+      - api.groq.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: What is the largest city in Mexico?
+        role: user
+      model: openai/gpt-oss-120b
+      n: 1
+      response_format:
+        json_schema:
+          description: A city and its country.
+          name: CityLocation
+          schema:
+            additionalProperties: false
+            properties:
+              city:
+                type: string
+              country:
+                type: string
+            required:
+            - city
+            - country
+            type: object
+          strict: true
+        type: json_schema
+      stream: false
+    uri: https://api.groq.com/openai/v1/chat/completions
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=86400
+      cache-control:
+      - private, max-age=0, no-store, no-cache, must-revalidate
+      connection:
+      - keep-alive
+      content-length:
+      - '947'
+      content-type:
+      - application/json
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+    parsed_body:
+      choices:
+      - finish_reason: stop
+        index: 0
+        logprobs: null
+        message:
+          content: '{"city":"Mexico City","country":"Mexico"}'
+          reasoning: 'The user asks: "What is the largest city in Mexico?" The system expects a JSON object conforming to
+            CityLocation schema: properties city (string) and country (string), required both. Provide largest city in Mexico:
+            Mexico City. So output JSON: {"city":"Mexico City","country":"Mexico"} in compact format, no extra text.'
+          role: assistant
+      created: 1756843265
+      id: chatcmpl-92437948-262c-49fe-87d1-774e54201105
+      model: openai/gpt-oss-120b
+      object: chat.completion
+      service_tier: on_demand
+      system_fingerprint: fp_213abb2467
+      usage:
+        completion_time: 0.186978247
+        completion_tokens: 94
+        prompt_time: 0.008149307
+        prompt_tokens: 178
+        queue_time: 1.095749809
+        total_time: 0.195127554
+        total_tokens: 272
+      usage_breakdown: null
+      x_groq:
+        id: req_01k4609j8zemga5d9erfnwbfma
+    status:
+      code: 200
+      message: OK
+version: 1
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '516'
+      content-type:
+      - application/json
+      host:
+      - api.groq.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: |-
+          Always respond with a JSON object that's compatible with this schema:
+
+          {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}
+
+          Don't include any text or Markdown fencing before or after.
+        role: system
+      - content: What is the largest city in Mexico?
+        role: user
+      model: openai/gpt-oss-120b
+      n: 1
+      response_format:
+        type: json_object
+      stream: false
+    uri: https://api.groq.com/openai/v1/chat/completions
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=86400
+      cache-control:
+      - private, max-age=0, no-store, no-cache, must-revalidate
+      connection:
+      - keep-alive
+      content-length:
+      - '926'
+      content-type:
+      - application/json
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+    parsed_body:
+      choices:
+      - finish_reason: stop
+        index: 0
+        logprobs: null
+        message:
+          content: '{"city":"Mexico City","country":"Mexico"}'
+          reasoning: 'We need to respond with JSON object with properties city and country. The question: "What is the largest
+            city in Mexico?" The answer: City is Mexico City, country is Mexico. Must output compact JSON without any extra
+            text or markdown. So {"city":"Mexico City","country":"Mexico"} Ensure valid JSON.'
+          role: assistant
+      created: 1756843266
+      id: chatcmpl-d7085def-1e9f-45d7-b90b-65633ef23489
+      model: openai/gpt-oss-120b
+      object: chat.completion
+      service_tier: on_demand
+      system_fingerprint: fp_ed9190d8b7
+      usage:
+        completion_time: 0.173182068
+        completion_tokens: 87
+        prompt_time: 0.006958709
+        prompt_tokens: 177
+        queue_time: 0.212268627
+        total_time: 0.180140777
+        total_tokens: 264
+      usage_breakdown: null
+      x_groq:
+        id: req_01k4609m4zf9qsbf0wwgvw8398
+    status:
+      code: 200
+      message: OK
+version: 1

tests/models/test_groq.py

Lines changed: 89 additions & 0 deletions
@@ -13,6 +13,7 @@
 import pytest
 from dirty_equals import IsListOrTuple
 from inline_snapshot import snapshot
+from pydantic import BaseModel
 from typing_extensions import TypedDict
 
 from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
@@ -37,6 +38,7 @@
     ToolReturnPart,
     UserPromptPart,
 )
+from pydantic_ai.output import NativeOutput, PromptedOutput
 from pydantic_ai.usage import RequestUsage, RunUsage
 
 from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import
@@ -1004,3 +1006,90 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_api_key: str):
             length=996,
         )
     )
+
+
+async def test_groq_native_output(allow_model_requests: None, groq_api_key: str):
+    m = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key=groq_api_key))
+
+    class CityLocation(BaseModel):
+        """A city and its country."""
+
+        city: str
+        country: str
+
+    agent = Agent(m, output_type=NativeOutput(CityLocation))
+
+    result = await agent.run('What is the largest city in Mexico?')
+    assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='What is the largest city in Mexico?',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[
+                    ThinkingPart(
+                        content='The user asks: "What is the largest city in Mexico?" The system expects a JSON object conforming to CityLocation schema: properties city (string) and country (string), required both. Provide largest city in Mexico: Mexico City. So output JSON: {"city":"Mexico City","country":"Mexico"} in compact format, no extra text.'
+                    ),
+                    TextPart(content='{"city":"Mexico City","country":"Mexico"}'),
+                ],
+                usage=RequestUsage(input_tokens=178, output_tokens=94),
+                model_name='openai/gpt-oss-120b',
+                timestamp=IsDatetime(),
+                provider_name='groq',
+                provider_response_id=IsStr(),
+            ),
+        ]
+    )
+
+
+async def test_groq_prompted_output(allow_model_requests: None, groq_api_key: str):
+    m = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key=groq_api_key))
+
+    class CityLocation(BaseModel):
+        city: str
+        country: str
+
+    agent = Agent(m, output_type=PromptedOutput(CityLocation))
+
+    result = await agent.run('What is the largest city in Mexico?')
+    assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='What is the largest city in Mexico?',
+                        timestamp=IsDatetime(),
+                    )
+                ],
+                instructions="""\
+Always respond with a JSON object that's compatible with this schema:
+
+{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}
+
+Don't include any text or Markdown fencing before or after.\
+""",
+            ),
+            ModelResponse(
+                parts=[
+                    ThinkingPart(
+                        content='We need to respond with JSON object with properties city and country. The question: "What is the largest city in Mexico?" The answer: City is Mexico City, country is Mexico. Must output compact JSON without any extra text or markdown. So {"city":"Mexico City","country":"Mexico"} Ensure valid JSON.'
+                    ),
+                    TextPart(content='{"city":"Mexico City","country":"Mexico"}'),
+                ],
+                usage=RequestUsage(input_tokens=177, output_tokens=87),
+                model_name='openai/gpt-oss-120b',
+                timestamp=IsDatetime(),
+                provider_name='groq',
+                provider_response_id=IsStr(),
+            ),
+        ]
+    )
