Skip to content

Commit 0337c5a

Browse files
committed
Merge branch 'main' into groq-tool-use-failed
# Conflicts: # tests/models/test_groq.py
2 parents a39582c + 55ad5bd commit 0337c5a

File tree

7 files changed

+341
-3
lines changed

7 files changed

+341
-3
lines changed

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from pydantic import BaseModel, Json, ValidationError
1111
from typing_extensions import assert_never
1212

13+
from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
14+
1315
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1416
from .._run_context import RunContext
1517
from .._thinking_part import split_content_into_text_and_thinking
@@ -242,6 +244,18 @@ async def _completions_create(
242244

243245
groq_messages = self._map_messages(messages)
244246

247+
response_format: chat.completion_create_params.ResponseFormat | None = None
248+
if model_request_parameters.output_mode == 'native':
249+
output_object = model_request_parameters.output_object
250+
assert output_object is not None
251+
response_format = self._map_json_schema(output_object)
252+
elif (
253+
model_request_parameters.output_mode == 'prompted'
254+
and not tools
255+
and self.profile.supports_json_object_output
256+
): # pragma: no branch
257+
response_format = {'type': 'json_object'}
258+
245259
try:
246260
extra_headers = model_settings.get('extra_headers', {})
247261
extra_headers.setdefault('User-Agent', get_user_agent())
@@ -254,6 +268,7 @@ async def _completions_create(
254268
tool_choice=tool_choice or NOT_GIVEN,
255269
stop=model_settings.get('stop_sequences', NOT_GIVEN),
256270
stream=stream,
271+
response_format=response_format or NOT_GIVEN,
257272
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
258273
temperature=model_settings.get('temperature', NOT_GIVEN),
259274
top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -399,6 +414,19 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
399414
},
400415
}
401416

417+
def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
418+
response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
419+
'type': 'json_schema',
420+
'json_schema': {
421+
'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
422+
'schema': o.json_schema,
423+
'strict': o.strict,
424+
},
425+
}
426+
if o.description: # pragma: no branch
427+
response_format_param['json_schema']['description'] = o.description
428+
return response_format_param
429+
402430
@classmethod
403431
def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
404432
for part in message.parts:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def _map_tool_call(t: ToolCallPart) -> ChatCompletionMessageFunctionToolCallPara
606606
def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
607607
response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage]
608608
'type': 'json_schema',
609-
'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True},
609+
'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
610610
}
611611
if o.description:
612612
response_format_param['json_schema']['description'] = o.description

pydantic_ai_slim/pydantic_ai/providers/groq.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic_ai.profiles.meta import meta_model_profile
1515
from pydantic_ai.profiles.mistral import mistral_model_profile
1616
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
17+
from pydantic_ai.profiles.openai import openai_model_profile
1718
from pydantic_ai.profiles.qwen import qwen_model_profile
1819
from pydantic_ai.providers import Provider
1920

@@ -26,6 +27,23 @@
2627
) from _import_error
2728

2829

30+
def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
31+
"""Get the model profile for a MoonshotAI model used with the Groq provider."""
32+
return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
33+
moonshotai_model_profile(model_name)
34+
)
35+
36+
37+
def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
38+
"""Get the model profile for a Meta model used with the Groq provider."""
39+
if model_name in {'llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct'}:
40+
return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
41+
meta_model_profile(model_name)
42+
)
43+
else:
44+
return meta_model_profile(model_name)
45+
46+
2947
class GroqProvider(Provider[AsyncGroq]):
3048
"""Provider for Groq API."""
3149

@@ -44,13 +62,14 @@ def client(self) -> AsyncGroq:
4462
def model_profile(self, model_name: str) -> ModelProfile | None:
4563
prefix_to_profile = {
4664
'llama': meta_model_profile,
47-
'meta-llama/': meta_model_profile,
65+
'meta-llama/': meta_groq_model_profile,
4866
'gemma': google_model_profile,
4967
'qwen': qwen_model_profile,
5068
'deepseek': deepseek_model_profile,
5169
'mistral': mistral_model_profile,
52-
'moonshotai/': moonshotai_model_profile,
70+
'moonshotai/': groq_moonshotai_model_profile,
5371
'compound-': groq_model_profile,
72+
'openai/': openai_model_profile,
5473
}
5574

5675
for prefix, profile_func in prefix_to_profile.items():
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '416'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.groq.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: What is the largest city in Mexico?
20+
role: user
21+
model: openai/gpt-oss-120b
22+
n: 1
23+
response_format:
24+
json_schema:
25+
description: A city and its country.
26+
name: CityLocation
27+
schema:
28+
additionalProperties: false
29+
properties:
30+
city:
31+
type: string
32+
country:
33+
type: string
34+
required:
35+
- city
36+
- country
37+
type: object
38+
strict: true
39+
type: json_schema
40+
stream: false
41+
uri: https://api.groq.com/openai/v1/chat/completions
42+
response:
43+
headers:
44+
alt-svc:
45+
- h3=":443"; ma=86400
46+
cache-control:
47+
- private, max-age=0, no-store, no-cache, must-revalidate
48+
connection:
49+
- keep-alive
50+
content-length:
51+
- '947'
52+
content-type:
53+
- application/json
54+
transfer-encoding:
55+
- chunked
56+
vary:
57+
- Origin
58+
parsed_body:
59+
choices:
60+
- finish_reason: stop
61+
index: 0
62+
logprobs: null
63+
message:
64+
content: '{"city":"Mexico City","country":"Mexico"}'
65+
reasoning: 'The user asks: "What is the largest city in Mexico?" The system expects a JSON object conforming to
66+
CityLocation schema: properties city (string) and country (string), required both. Provide largest city in Mexico:
67+
Mexico City. So output JSON: {"city":"Mexico City","country":"Mexico"} in compact format, no extra text.'
68+
role: assistant
69+
created: 1756843265
70+
id: chatcmpl-92437948-262c-49fe-87d1-774e54201105
71+
model: openai/gpt-oss-120b
72+
object: chat.completion
73+
service_tier: on_demand
74+
system_fingerprint: fp_213abb2467
75+
usage:
76+
completion_time: 0.186978247
77+
completion_tokens: 94
78+
prompt_time: 0.008149307
79+
prompt_tokens: 178
80+
queue_time: 1.095749809
81+
total_time: 0.195127554
82+
total_tokens: 272
83+
usage_breakdown: null
84+
x_groq:
85+
id: req_01k4609j8zemga5d9erfnwbfma
86+
status:
87+
code: 200
88+
message: OK
89+
version: 1
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '516'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.groq.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: |-
20+
Always respond with a JSON object that's compatible with this schema:
21+
22+
{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}
23+
24+
Don't include any text or Markdown fencing before or after.
25+
role: system
26+
- content: What is the largest city in Mexico?
27+
role: user
28+
model: openai/gpt-oss-120b
29+
n: 1
30+
response_format:
31+
type: json_object
32+
stream: false
33+
uri: https://api.groq.com/openai/v1/chat/completions
34+
response:
35+
headers:
36+
alt-svc:
37+
- h3=":443"; ma=86400
38+
cache-control:
39+
- private, max-age=0, no-store, no-cache, must-revalidate
40+
connection:
41+
- keep-alive
42+
content-length:
43+
- '926'
44+
content-type:
45+
- application/json
46+
transfer-encoding:
47+
- chunked
48+
vary:
49+
- Origin
50+
parsed_body:
51+
choices:
52+
- finish_reason: stop
53+
index: 0
54+
logprobs: null
55+
message:
56+
content: '{"city":"Mexico City","country":"Mexico"}'
57+
reasoning: 'We need to respond with JSON object with properties city and country. The question: "What is the largest
58+
city in Mexico?" The answer: City is Mexico City, country is Mexico. Must output compact JSON without any extra
59+
text or markdown. So {"city":"Mexico City","country":"Mexico"} Ensure valid JSON.'
60+
role: assistant
61+
created: 1756843266
62+
id: chatcmpl-d7085def-1e9f-45d7-b90b-65633ef23489
63+
model: openai/gpt-oss-120b
64+
object: chat.completion
65+
service_tier: on_demand
66+
system_fingerprint: fp_ed9190d8b7
67+
usage:
68+
completion_time: 0.173182068
69+
completion_tokens: 87
70+
prompt_time: 0.006958709
71+
prompt_tokens: 177
72+
queue_time: 0.212268627
73+
total_time: 0.180140777
74+
total_tokens: 264
75+
usage_breakdown: null
76+
x_groq:
77+
id: req_01k4609m4zf9qsbf0wwgvw8398
78+
status:
79+
code: 200
80+
message: OK
81+
version: 1

tests/models/test_groq.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pytest
1414
from dirty_equals import IsListOrTuple
1515
from inline_snapshot import snapshot
16+
from pydantic import BaseModel
1617
from typing_extensions import TypedDict
1718

1819
from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
@@ -37,6 +38,7 @@
3738
ToolReturnPart,
3839
UserPromptPart,
3940
)
41+
from pydantic_ai.output import NativeOutput, PromptedOutput
4042
from pydantic_ai.usage import RequestUsage, RunUsage
4143

4244
from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import
@@ -1217,3 +1219,90 @@ async def test_tool_regular_error(allow_model_requests: None, groq_api_key: str)
12171219
ModelHTTPError, match='The model `non-existent` does not exist or you do not have access to it.'
12181220
):
12191221
await agent.run('hello')
1222+
1223+
1224+
async def test_groq_native_output(allow_model_requests: None, groq_api_key: str):
1225+
m = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key=groq_api_key))
1226+
1227+
class CityLocation(BaseModel):
1228+
"""A city and its country."""
1229+
1230+
city: str
1231+
country: str
1232+
1233+
agent = Agent(m, output_type=NativeOutput(CityLocation))
1234+
1235+
result = await agent.run('What is the largest city in Mexico?')
1236+
assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))
1237+
1238+
assert result.all_messages() == snapshot(
1239+
[
1240+
ModelRequest(
1241+
parts=[
1242+
UserPromptPart(
1243+
content='What is the largest city in Mexico?',
1244+
timestamp=IsDatetime(),
1245+
)
1246+
]
1247+
),
1248+
ModelResponse(
1249+
parts=[
1250+
ThinkingPart(
1251+
content='The user asks: "What is the largest city in Mexico?" The system expects a JSON object conforming to CityLocation schema: properties city (string) and country (string), required both. Provide largest city in Mexico: Mexico City. So output JSON: {"city":"Mexico City","country":"Mexico"} in compact format, no extra text.'
1252+
),
1253+
TextPart(content='{"city":"Mexico City","country":"Mexico"}'),
1254+
],
1255+
usage=RequestUsage(input_tokens=178, output_tokens=94),
1256+
model_name='openai/gpt-oss-120b',
1257+
timestamp=IsDatetime(),
1258+
provider_name='groq',
1259+
provider_response_id=IsStr(),
1260+
),
1261+
]
1262+
)
1263+
1264+
1265+
async def test_groq_prompted_output(allow_model_requests: None, groq_api_key: str):
1266+
m = GroqModel('openai/gpt-oss-120b', provider=GroqProvider(api_key=groq_api_key))
1267+
1268+
class CityLocation(BaseModel):
1269+
city: str
1270+
country: str
1271+
1272+
agent = Agent(m, output_type=PromptedOutput(CityLocation))
1273+
1274+
result = await agent.run('What is the largest city in Mexico?')
1275+
assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))
1276+
1277+
assert result.all_messages() == snapshot(
1278+
[
1279+
ModelRequest(
1280+
parts=[
1281+
UserPromptPart(
1282+
content='What is the largest city in Mexico?',
1283+
timestamp=IsDatetime(),
1284+
)
1285+
],
1286+
instructions="""\
1287+
Always respond with a JSON object that's compatible with this schema:
1288+
1289+
{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}
1290+
1291+
Don't include any text or Markdown fencing before or after.\
1292+
""",
1293+
),
1294+
ModelResponse(
1295+
parts=[
1296+
ThinkingPart(
1297+
content='We need to respond with JSON object with properties city and country. The question: "What is the largest city in Mexico?" The answer: City is Mexico City, country is Mexico. Must output compact JSON without any extra text or markdown. So {"city":"Mexico City","country":"Mexico"} Ensure valid JSON.'
1298+
),
1299+
TextPart(content='{"city":"Mexico City","country":"Mexico"}'),
1300+
],
1301+
usage=RequestUsage(input_tokens=177, output_tokens=87),
1302+
model_name='openai/gpt-oss-120b',
1303+
timestamp=IsDatetime(),
1304+
provider_name='groq',
1305+
provider_response_id=IsStr(),
1306+
),
1307+
]
1308+
)

0 commit comments

Comments
 (0)