     )
     from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
     from mistralai.models.function import Function as MistralFunction
+    from mistralai.models.prediction import (
+        Prediction as MistralPrediction,
+        PredictionTypedDict as MistralPredictionTypedDict,
+    )
     from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
     from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
     from mistralai.models.usermessage import UserMessage as MistralUserMessage
@@ -114,8 +118,13 @@ class MistralModelSettings(ModelSettings, total=False):
     """Settings used for a Mistral model request."""
 
     # ALL FIELDS MUST BE `mistral_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    mistral_prediction: str | MistralPrediction | MistralPredictionTypedDict | None
+    """Prediction content for the model to use as a prefix. See the Predictive outputs section of the Mistral documentation.
 
-    # This class is a placeholder for any future mistral-specific settings
+    This feature is currently only supported by certain Mistral models; see the model cards in the Mistral Models documentation.
+    For example, it is supported by the latest Mistral series Large (> 2), Medium (> 3), Small (> 3), and Pixtral models,
+    but not yet by the reasoning or coding models.
+    """
 
 
 @dataclass(init=False)
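For illustration only (not part of this diff): a minimal sketch of how the new setting could be passed from application code, assuming pydantic-ai's `Agent` API; the model name and prompts are placeholders.

```python
# Sketch, not part of the PR: passing `mistral_prediction` through MistralModelSettings.
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModelSettings

agent = Agent('mistral:mistral-large-latest')  # placeholder model choice

# A plain string is accepted; it is wrapped into a Mistral Prediction and sent to the
# API as predicted output, which can speed up completions that are mostly known ahead of time.
settings = MistralModelSettings(
    mistral_prediction='def add(a: int, b: int) -> int:\n    return a + b',
)
result = agent.run_sync(
    'Add type hints to this function: def add(a, b): return a + b',
    model_settings=settings,
)
print(result.output)  # `.output` on recent pydantic-ai versions
```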
@@ -241,6 +250,7 @@ async def _completions_create(
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 random_seed=model_settings.get('seed', UNSET),
                 stop=model_settings.get('stop_sequences', None),
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
         except SDKError as e:
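For context, the argument added above is forwarded directly to the Mistral SDK call; the following is a rough SDK-level equivalent (a sketch only, with placeholder API key, model, and prompt, using the synchronous `complete` variant rather than the `complete_async` call shown in the diff).

```python
# Sketch, not part of the PR: what the wiring above amounts to at the SDK level.
from mistralai import Mistral
from mistralai.models.prediction import Prediction as MistralPrediction

client = Mistral(api_key='MISTRAL_API_KEY')  # placeholder
response = client.chat.complete(  # mirrors the complete_async call in the diff
    model='mistral-large-latest',  # placeholder model
    messages=[{'role': 'user', 'content': 'Rename variable `x` to `total` in this code: ...'}],
    prediction=MistralPrediction(content='total = 0\nfor item in items:\n    total += item'),
)
print(response.choices[0].message.content)
```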
@@ -281,6 +291,7 @@ async def _stream_completions_create(
                 presence_penalty=model_settings.get('presence_penalty'),
                 frequency_penalty=model_settings.get('frequency_penalty'),
                 stop=model_settings.get('stop_sequences', None),
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
 
@@ -298,6 +309,7 @@ async def _stream_completions_create(
                     'type': 'json_object'
                 },  # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
                 stream=True,
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
 
@@ -307,6 +319,7 @@ async def _stream_completions_create(
                 model=str(self._model_name),
                 messages=mistral_messages,
                 stream=True,
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
         assert response, 'A unexpected empty response from Mistral.'
@@ -427,6 +440,24 @@ def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
             function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}),
         )
 
+    @staticmethod
+    def _map_setting_prediction(
+        prediction: str | MistralPredictionTypedDict | MistralPrediction | None,
+    ) -> MistralPrediction | None:
+        """Maps various prediction input types to a MistralPrediction object."""
+        if not prediction:
+            return None
+        if isinstance(prediction, MistralPrediction):
+            return prediction
+        elif isinstance(prediction, str):
+            return MistralPrediction(content=prediction)
+        elif isinstance(prediction, dict):
+            return MistralPrediction.model_validate(prediction)
+        else:
+            raise RuntimeError(
+                f'Unsupported prediction type: {type(prediction)} for MistralModelSettings. Expected str, dict, or MistralPrediction.'
+            )
+
     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""
         examples: list[dict[str, Any]] = []
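The helper above accepts three input forms and normalizes all of them to a `MistralPrediction`; the following sketch illustrates the expected equivalences, calling the private static method directly purely for demonstration (it assumes the `type` field of `Prediction` defaults to `'content'`).

```python
# Sketch, not part of the PR: the three accepted forms of `mistral_prediction`.
from mistralai.models.prediction import Prediction as MistralPrediction
from pydantic_ai.models.mistral import MistralModel

snippet = 'def add(a, b):\n    return a + b'

p_from_str = MistralModel._map_setting_prediction(snippet)
p_from_dict = MistralModel._map_setting_prediction({'type': 'content', 'content': snippet})
p_from_obj = MistralModel._map_setting_prediction(MistralPrediction(content=snippet))

# All three normalize to the same Prediction payload; None or empty values map to None.
assert p_from_str == p_from_dict == p_from_obj
assert MistralModel._map_setting_prediction(None) is None
```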