Skip to content

Commit 71e3b82

Browse files
committed
Update tests and model documentation for revised Mistral prediction handling.
1 parent 52420be commit 71e3b82

File tree

3 files changed

+35
-34
lines changed

3 files changed

+35
-34
lines changed

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ class MistralModelSettings(ModelSettings, total=False):
122122
"""Prediction content for the model to use as a prefix. See Predictive outputs.
123123
124124
This feature is currently only supported for certain Mistral models. See the model cards at Models.
125-
For example, it is supported for the latest Mistral Serie Large (> 2), Medium (> 3), Small (> 3) and Pixtral models,
126-
but not for reasoning or coding models yet.
125+
As of now, codestral-latest and mistral-large-2411 support [predicted outputs](https://docs.mistral.ai/capabilities/predicted_outputs).
127126
"""
128127

129128

tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ interactions:
88
connection:
99
- keep-alive
1010
content-length:
11-
- '246'
11+
- '315'
1212
content-type:
1313
- application/json
1414
host:
@@ -17,13 +17,16 @@ interactions:
1717
parsed_body:
1818
messages:
1919
- content:
20-
- text: Correct only the math, respond with no explanation, no formatting.
20+
- text: Correct the math, keep everything else. No explanation, no formatting.
2121
type: text
2222
- text: The result of 21+21=99
2323
type: text
2424
role: user
25-
model: mistral-small-latest
25+
model: mistral-large-2411
2626
n: 1
27+
prediction:
28+
content: The result of 21+21=99
29+
type: content
2730
stream: false
2831
top_p: 1.0
2932
uri: https://api.mistral.ai/v1/chat/completions
@@ -36,11 +39,11 @@ interactions:
3639
connection:
3740
- keep-alive
3841
content-length:
39-
- '321'
42+
- '319'
4043
content-type:
4144
- application/json
4245
mistral-correlation-id:
43-
- 019a639b-7bf4-7481-96af-078cd1a7d277
46+
- 019a63b7-40ba-70cb-94d0-84f036d7c76f
4447
strict-transport-security:
4548
- max-age=15552000; includeSubDomains; preload
4649
transfer-encoding:
@@ -53,14 +56,14 @@ interactions:
5356
content: The result of 21+21=42
5457
role: assistant
5558
tool_calls: null
56-
created: 1762607725
57-
id: a7952046ef794d1697627b54231df17a
58-
model: mistral-small-latest
59+
created: 1762609545
60+
id: 6c36e8b6c3c145bd8ada32f9bd0f6be9
61+
model: mistral-large-2411
5962
object: chat.completion
6063
usage:
6164
completion_tokens: 13
62-
prompt_tokens: 28
63-
total_tokens: 41
65+
prompt_tokens: 33
66+
total_tokens: 46
6467
status:
6568
code: 200
6669
message: OK

tests/models/test_mistral.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@
5353
SDKError,
5454
ToolCall as MistralToolCall,
5555
)
56-
from mistralai.types.basemodel import Unset as MistralUnset
5756
from mistralai.models.prediction import (
5857
Prediction as MistralPrediction,
5958
PredictionTypedDict as MistralPredictionTypedDict,
6059
)
60+
from mistralai.types.basemodel import Unset as MistralUnset
6161

6262
from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse
6363
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
@@ -1681,44 +1681,45 @@ async def get_location(loc_name: str) -> str:
16811681
#####################
16821682
## Test methods
16831683
#####################
1684-
# --- _map_setting_prediction --------------------------------------------------
16851684
@pytest.fixture
16861685
def example_dict() -> MistralPredictionTypedDict:
16871686
"""Fixture providing a typed dict for prediction."""
1688-
return {"type": "content", "content": "foo"}
1687+
return {'type': 'content', 'content': 'foo'}
16891688

16901689

16911690
@pytest.fixture
16921691
def example_prediction() -> MistralPrediction:
16931692
"""Fixture providing a MistralPrediction object."""
1694-
return MistralPrediction(content="bar")
1693+
return MistralPrediction(content='bar')
16951694

16961695

16971696
@pytest.mark.parametrize(
1698-
"input_value,expected_content",
1697+
'input_value,expected_content',
16991698
[
1700-
("plain text", "plain text"),
1701-
("example_prediction", "bar"),
1702-
("example_dict", "foo"),
1699+
('plain text', 'plain text'),
1700+
('example_prediction', 'bar'),
1701+
('example_dict', 'foo'),
17031702
(None, None),
17041703
],
17051704
)
1706-
def test_map_setting_prediction_valid(request, input_value, expected_content):
1705+
def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_value: str, expected_content: str | None):
17071706
"""
17081707
Accepted input types (str, dict, MistralPrediction, None) must be correctly converted to a MistralPrediction or None.
17091708
"""
17101709
# If the parameter is a fixture name, resolve it using request
1711-
if isinstance(input_value, str) and input_value in {"example_dict", "example_prediction"}:
1712-
input_value = request.getfixturevalue(input_value)
1710+
resolved_value: str | MistralPredictionTypedDict | MistralPrediction | None = input_value
1711+
if isinstance(input_value, str) and input_value in {'example_dict', 'example_prediction'}:
1712+
resolved_value = request.getfixturevalue(input_value)
17131713

1714-
result = MistralModel._map_setting_prediction(input_value) # pyright: ignore[reportPrivateUsage]
1714+
result = MistralModel._map_setting_prediction(resolved_value) # pyright: ignore[reportPrivateUsage]
17151715

1716-
if input_value is None:
1716+
if resolved_value is None:
17171717
assert result is None
17181718
else:
17191719
assert isinstance(result, MistralPrediction)
17201720
assert result.content == expected_content
1721-
# -----------------------------------------------------
1721+
1722+
17221723
def test_generate_user_output_format_complex(mistral_api_key: str):
17231724
"""
17241725
Single test that includes properties exercising every branch
@@ -2304,21 +2305,19 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist
23042305
]
23052306
)
23062307

2308+
23072309
@pytest.mark.vcr()
23082310
async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str):
23092311
"""Test chat completion with prediction parameter using a math query."""
23102312
from pydantic_ai.models.mistral import MistralModelSettings
23112313

2312-
model = MistralModel(
2313-
'mistral-small-latest',
2314-
provider=MistralProvider(api_key=mistral_api_key)
2315-
)
2316-
prediction = "The result of 21+21=99"
2317-
settings = MistralModelSettings(prediction=prediction)
2314+
model = MistralModel('mistral-large-2411', provider=MistralProvider(api_key=mistral_api_key))
2315+
prediction = 'The result of 21+21=99'
2316+
settings = MistralModelSettings(mistral_prediction=prediction)
23182317
agent = Agent(model, model_settings=settings)
23192318

2320-
result = await agent.run(['Correct only the math, respond with no explanation, no formatting.',"The result of 21+21=99"])
2319+
result = await agent.run(['Correct the math, keep everything else. No explanation, no formatting.', prediction])
23212320

23222321
# Verify that the response uses the expected prediction
2323-
assert 'The result of 21+21=' in result.output
23242322
assert '42' in result.output
2323+
assert 'The result of 21+21=' in result.output

0 commit comments

Comments
 (0)