
Commit 52420be

Add prediction support for MistralModel and associated tests.
1 parent 86b645f · commit 52420be

3 files changed: +158 −1 lines changed

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 32 additions & 1 deletion
@@ -77,6 +77,10 @@
 )
 from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
 from mistralai.models.function import Function as MistralFunction
+from mistralai.models.prediction import (
+    Prediction as MistralPrediction,
+    PredictionTypedDict as MistralPredictionTypedDict,
+)
 from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
 from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
 from mistralai.models.usermessage import UserMessage as MistralUserMessage
@@ -114,8 +118,13 @@ class MistralModelSettings(ModelSettings, total=False):
     """Settings used for a Mistral model request."""
 
     # ALL FIELDS MUST BE `mistral_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    mistral_prediction: str | MistralPrediction | MistralPredictionTypedDict | None
+    """Prediction content for the model to use as a prefix. See the Predictive outputs page in the Mistral docs.
 
-    # This class is a placeholder for any future mistral-specific settings
+    This feature is currently only supported by certain Mistral models; see the model cards on the Models page.
+    For example, it is supported by the latest Mistral Large (> 2), Medium (> 3), Small (> 3) and Pixtral model series,
+    but not yet by the reasoning or coding models.
+    """
 
 
 @dataclass(init=False)
@@ -241,6 +250,7 @@ async def _completions_create(
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 random_seed=model_settings.get('seed', UNSET),
                 stop=model_settings.get('stop_sequences', None),
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
         except SDKError as e:
@@ -281,6 +291,7 @@ async def _stream_completions_create(
                 presence_penalty=model_settings.get('presence_penalty'),
                 frequency_penalty=model_settings.get('frequency_penalty'),
                 stop=model_settings.get('stop_sequences', None),
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
 
@@ -298,6 +309,7 @@ async def _stream_completions_create(
                     'type': 'json_object'
                 },  # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
                 stream=True,
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
 
@@ -307,6 +319,7 @@ async def _stream_completions_create(
                 model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
         assert response, 'A unexpected empty response from Mistral.'
@@ -427,6 +440,24 @@ def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
             function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}),
         )
 
+    @staticmethod
+    def _map_setting_prediction(
+        prediction: str | MistralPredictionTypedDict | MistralPrediction | None,
+    ) -> MistralPrediction | None:
+        """Maps the supported prediction input types to a MistralPrediction object."""
+        if not prediction:
+            return None
+        if isinstance(prediction, MistralPrediction):
+            return prediction
+        elif isinstance(prediction, str):
+            return MistralPrediction(content=prediction)
+        elif isinstance(prediction, dict):
+            return MistralPrediction.model_validate(prediction)
+        else:
+            raise RuntimeError(
+                f'Unsupported prediction type: {type(prediction)} for MistralModelSettings. Expected str, dict, or MistralPrediction.'
+            )
+
     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""
         examples: list[dict[str, Any]] = []
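
As a quick illustration of the new helper (a sketch, not part of the commit): each accepted input form normalizes to the same MistralPrediction, and the truthiness check means falsy values such as an empty string disable the feature rather than sending an empty prefix. `_map_setting_prediction` is private and is called directly here only for demonstration.

    from mistralai.models.prediction import Prediction as MistralPrediction

    from pydantic_ai.models.mistral import MistralModel

    # str, dict, and MistralPrediction inputs all normalize to the same object.
    # Assumption: Prediction's `type` field defaults to 'content', as the
    # MistralPrediction(content=...) call in the helper implies.
    expected = MistralPrediction(content='42')
    assert MistralModel._map_setting_prediction('42') == expected
    assert MistralModel._map_setting_prediction({'type': 'content', 'content': '42'}) == expected
    assert MistralModel._map_setting_prediction(expected) is expected

    # Falsy inputs (None, '') short-circuit to None, i.e. no prediction is sent:
    assert MistralModel._map_setting_prediction(None) is None
    assert MistralModel._map_setting_prediction('') is None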
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '246'
+      content-type:
+      - application/json
+      host:
+      - api.mistral.ai
+    method: POST
+    parsed_body:
+      messages:
+      - content:
+        - text: Correct only the math, respond with no explanation, no formatting.
+          type: text
+        - text: The result of 21+21=99
+          type: text
+        role: user
+      model: mistral-small-latest
+      n: 1
+      stream: false
+      top_p: 1.0
+    uri: https://api.mistral.ai/v1/chat/completions
+  response:
+    headers:
+      access-control-allow-origin:
+      - '*'
+      alt-svc:
+      - h3=":443"; ma=86400
+      connection:
+      - keep-alive
+      content-length:
+      - '321'
+      content-type:
+      - application/json
+      mistral-correlation-id:
+      - 019a639b-7bf4-7481-96af-078cd1a7d277
+      strict-transport-security:
+      - max-age=15552000; includeSubDomains; preload
+      transfer-encoding:
+      - chunked
+    parsed_body:
+      choices:
+      - finish_reason: stop
+        index: 0
+        message:
+          content: The result of 21+21=42
+          role: assistant
+          tool_calls: null
+      created: 1762607725
+      id: a7952046ef794d1697627b54231df17a
+      model: mistral-small-latest
+      object: chat.completion
+      usage:
+        completion_tokens: 13
+        prompt_tokens: 28
+        total_tokens: 41
+    status:
+      code: 200
+      message: OK
+version: 1
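
When a prediction is supplied, the SDK sends it in the chat-completions request body alongside `messages`. A sketch of the serialized shape (an assumption based on the `MistralPredictionTypedDict` fixture in the tests below; `Prediction` is a pydantic model, so `model_dump()` should yield the same dict):

    from mistralai.models.prediction import Prediction as MistralPrediction

    p = MistralPrediction(content='The result of 21+21=99')
    # Assumed request-body fragment: {'type': 'content', 'content': 'The result of 21+21=99'}
    print(p.model_dump())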

tests/models/test_mistral.py

Lines changed: 59 additions & 0 deletions
@@ -54,6 +54,10 @@
     ToolCall as MistralToolCall,
 )
 from mistralai.types.basemodel import Unset as MistralUnset
+from mistralai.models.prediction import (
+    Prediction as MistralPrediction,
+    PredictionTypedDict as MistralPredictionTypedDict,
+)
 
 from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse
 from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
@@ -1677,8 +1681,44 @@ async def get_location(loc_name: str) -> str:
 #####################
 ## Test methods
 #####################
+# --- _map_setting_prediction --------------------------------------------------
+@pytest.fixture
+def example_dict() -> MistralPredictionTypedDict:
+    """Fixture providing a typed dict for prediction."""
+    return {"type": "content", "content": "foo"}
+
 
+@pytest.fixture
+def example_prediction() -> MistralPrediction:
+    """Fixture providing a MistralPrediction object."""
+    return MistralPrediction(content="bar")
 
+
+@pytest.mark.parametrize(
+    "input_value,expected_content",
+    [
+        ("plain text", "plain text"),
+        ("example_prediction", "bar"),
+        ("example_dict", "foo"),
+        (None, None),
+    ],
+)
+def test_map_setting_prediction_valid(request, input_value, expected_content):
+    """
+    Accepted input types (str, dict, MistralPrediction, None) must be correctly converted to a MistralPrediction or None.
+    """
+    # If the parameter is a fixture name, resolve it using request
+    if isinstance(input_value, str) and input_value in {"example_dict", "example_prediction"}:
+        input_value = request.getfixturevalue(input_value)
+
+    result = MistralModel._map_setting_prediction(input_value)  # pyright: ignore[reportPrivateUsage]
+
+    if input_value is None:
+        assert result is None
+    else:
+        assert isinstance(result, MistralPrediction)
+        assert result.content == expected_content
+# -----------------------------------------------------
 def test_generate_user_output_format_complex(mistral_api_key: str):
     """
     Single test that includes properties exercising every branch
@@ -2263,3 +2303,22 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist
             ),
         ]
     )
+
+@pytest.mark.vcr()
+async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str):
+    """Test chat completion with the prediction setting using a math query."""
+    from pydantic_ai.models.mistral import MistralModelSettings
+
+    model = MistralModel(
+        'mistral-small-latest',
+        provider=MistralProvider(api_key=mistral_api_key)
+    )
+    prediction = 'The result of 21+21=99'
+    settings = MistralModelSettings(mistral_prediction=prediction)
+    agent = Agent(model, model_settings=settings)
+
+    result = await agent.run(['Correct only the math, respond with no explanation, no formatting.', 'The result of 21+21=99'])
+
+    # Verify that the response uses the expected prediction
+    assert 'The result of 21+21=' in result.output
+    assert '42' in result.output
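
For context, this is how an end user would opt in to predicted outputs once this commit lands; a minimal sketch mirroring the test above (API-key handling is illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.mistral import MistralModel, MistralModelSettings
    from pydantic_ai.providers.mistral import MistralProvider

    model = MistralModel('mistral-small-latest', provider=MistralProvider(api_key='your-api-key'))
    # The draft the model should reuse verbatim wherever it already matches the answer:
    settings = MistralModelSettings(mistral_prediction='The result of 21+21=99')
    agent = Agent(model, model_settings=settings)

    result = agent.run_sync(['Correct only the math, respond with no explanation, no formatting.', 'The result of 21+21=99'])
    print(result.output)  # expected to contain 'The result of 21+21=42'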
