Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
)
from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
from mistralai.models.function import Function as MistralFunction
from mistralai.models.prediction import (
Prediction as MistralPrediction,
PredictionTypedDict as MistralPredictionTypedDict,
)
from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
from mistralai.models.usermessage import UserMessage as MistralUserMessage
Expand Down Expand Up @@ -114,8 +118,12 @@ class MistralModelSettings(ModelSettings, total=False):
"""Settings used for a Mistral model request."""

# ALL FIELDS MUST BE `mistral_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
mistral_prediction: str | MistralPrediction | MistralPredictionTypedDict | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we support only str? It looks like the types don't have any additional fields, and None is unnecessary as they key can just be omitted from the dict.

Also if you're up for updating the OpenAI equivalent to support str as well, that'd be great :)

"""Prediction content for the model to use as a prefix. See Predictive outputs.

# This class is a placeholder for any future mistral-specific settings
This feature is currently only supported for certain Mistral models. See the model cards at Models.
As of now, codestral-latest and mistral-large-2411 support [predicted outputs](https://docs.mistral.ai/capabilities/predicted_outputs).
"""


@dataclass(init=False)
Expand Down Expand Up @@ -241,6 +249,7 @@ async def _completions_create(
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
random_seed=model_settings.get('seed', UNSET),
stop=model_settings.get('stop_sequences', None),
prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
http_headers={'User-Agent': get_user_agent()},
)
except SDKError as e:
Expand Down Expand Up @@ -281,6 +290,7 @@ async def _stream_completions_create(
presence_penalty=model_settings.get('presence_penalty'),
frequency_penalty=model_settings.get('frequency_penalty'),
stop=model_settings.get('stop_sequences', None),
prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
http_headers={'User-Agent': get_user_agent()},
)

Expand All @@ -298,6 +308,7 @@ async def _stream_completions_create(
'type': 'json_object'
}, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
stream=True,
prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
http_headers={'User-Agent': get_user_agent()},
)

Expand All @@ -307,6 +318,7 @@ async def _stream_completions_create(
model=str(self._model_name),
messages=mistral_messages,
stream=True,
prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
http_headers={'User-Agent': get_user_agent()},
)
assert response, 'A unexpected empty response from Mistral.'
Expand Down Expand Up @@ -427,6 +439,24 @@ def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}),
)

@staticmethod
def _map_setting_prediction(
    prediction: str | MistralPredictionTypedDict | MistralPrediction | None,
) -> MistralPrediction | None:
    """Normalize the `mistral_prediction` setting into a `MistralPrediction`.

    Args:
        prediction: The user-supplied prediction content. May be a plain string
            (the common case), an already-built `MistralPrediction`, the
            equivalent `TypedDict`, or `None`/empty when the setting is unset.

    Returns:
        A `MistralPrediction` ready to pass to the Mistral SDK, or `None` when
        no prediction was provided (falsy values are treated as "not set" so
        the key is simply omitted from the request).

    Raises:
        RuntimeError: If `prediction` is of an unsupported type.
    """
    if not prediction:
        return None
    if isinstance(prediction, MistralPrediction):
        return prediction
    if isinstance(prediction, str):
        # Plain string shorthand: wrap it as content-type prediction.
        return MistralPrediction(content=prediction)
    if isinstance(prediction, dict):
        # TypedDict input: let pydantic validate the shape.
        return MistralPrediction.model_validate(prediction)
    raise RuntimeError(
        f'Unsupported prediction type: {type(prediction)} for MistralModelSettings. Expected str, dict, or MistralPrediction.'
    )

def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
"""Get a message with an example of the expected output format."""
examples: list[dict[str, Any]] = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
interactions:
- request:
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '315'
content-type:
- application/json
host:
- api.mistral.ai
method: POST
parsed_body:
messages:
- content:
- text: Correct the math, keep everything else. No explanation, no formatting.
type: text
- text: The result of 21+21=99
type: text
role: user
model: mistral-large-2411
n: 1
prediction:
content: The result of 21+21=99
type: content
stream: false
top_p: 1.0
uri: https://api.mistral.ai/v1/chat/completions
response:
headers:
access-control-allow-origin:
- '*'
alt-svc:
- h3=":443"; ma=86400
connection:
- keep-alive
content-length:
- '319'
content-type:
- application/json
mistral-correlation-id:
- 019a63b7-40ba-70cb-94d0-84f036d7c76f
strict-transport-security:
- max-age=15552000; includeSubDomains; preload
transfer-encoding:
- chunked
parsed_body:
choices:
- finish_reason: stop
index: 0
message:
content: The result of 21+21=42
role: assistant
tool_calls: null
created: 1762609545
id: 6c36e8b6c3c145bd8ada32f9bd0f6be9
model: mistral-large-2411
object: chat.completion
usage:
completion_tokens: 13
prompt_tokens: 33
total_tokens: 46
status:
code: 200
message: OK
version: 1
66 changes: 66 additions & 0 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
SDKError,
ToolCall as MistralToolCall,
)
from mistralai.models.prediction import (
Prediction as MistralPrediction,
PredictionTypedDict as MistralPredictionTypedDict,
)
from mistralai.types.basemodel import Unset as MistralUnset

from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse
Expand Down Expand Up @@ -1677,6 +1681,51 @@ async def get_location(loc_name: str) -> str:
#####################
## Test methods
#####################
@pytest.fixture
def example_dict() -> MistralPredictionTypedDict:
    """Prediction payload in `TypedDict` form, as a caller might pass it."""
    payload: MistralPredictionTypedDict = {'type': 'content', 'content': 'foo'}
    return payload


@pytest.fixture
def example_prediction() -> MistralPrediction:
    """Prediction payload as an already-constructed `MistralPrediction` object."""
    prediction = MistralPrediction(content='bar')
    return prediction


@pytest.mark.parametrize(
    'input_value,expected_content',
    [
        ('plain text', 'plain text'),
        ('example_prediction', 'bar'),
        ('example_dict', 'foo'),
        (None, None),
    ],
)
def test_map_setting_prediction_valid(
    request: pytest.FixtureRequest, input_value: str | None, expected_content: str | None
):
    """Accepted input types (str, dict, MistralPrediction, None) must be correctly converted to a MistralPrediction or None."""
    # If the parameter is a fixture name, resolve it using request.
    resolved_value: str | MistralPredictionTypedDict | MistralPrediction | None = input_value
    if isinstance(input_value, str) and input_value in {'example_dict', 'example_prediction'}:
        resolved_value = request.getfixturevalue(input_value)

    result = MistralModel._map_setting_prediction(resolved_value)  # pyright: ignore[reportPrivateUsage]

    if resolved_value is None:
        assert result is None
    else:
        assert isinstance(result, MistralPrediction)
        assert result.content == expected_content


def test_map_setting_prediction_unsupported_type():
    """_map_setting_prediction must raise RuntimeError for values that are not str, dict, or MistralPrediction."""
    expected_message = 'Unsupported prediction type.*int.*Expected str, dict, or MistralPrediction'
    with pytest.raises(RuntimeError, match=expected_message):
        MistralModel._map_setting_prediction(123)  # pyright: ignore[reportPrivateUsage, reportArgumentType]


def test_generate_user_output_format_complex(mistral_api_key: str):
Expand Down Expand Up @@ -2263,3 +2312,20 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist
),
]
)


@pytest.mark.vcr()
async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str):
    """Test chat completion with prediction parameter using a math query."""
    from pydantic_ai.models.mistral import MistralModelSettings

    prediction = 'The result of 21+21=99'
    provider = MistralProvider(api_key=mistral_api_key)
    model = MistralModel('mistral-large-2411', provider=provider)
    agent = Agent(model, model_settings=MistralModelSettings(mistral_prediction=prediction))

    prompt = ['Correct the math, keep everything else. No explanation, no formatting.', prediction]
    result = await agent.run(prompt)

    # Verify that the response uses the expected prediction
    assert 'The result of 21+21=' in result.output
    assert '42' in result.output
Loading