|
53 | 53 | SDKError, |
54 | 54 | ToolCall as MistralToolCall, |
55 | 55 | ) |
56 | | - from mistralai.types.basemodel import Unset as MistralUnset |
57 | 56 | from mistralai.models.prediction import ( |
58 | 57 | Prediction as MistralPrediction, |
59 | 58 | PredictionTypedDict as MistralPredictionTypedDict, |
60 | 59 | ) |
| 60 | + from mistralai.types.basemodel import Unset as MistralUnset |
61 | 61 |
|
62 | 62 | from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse |
63 | 63 | from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings |
@@ -1681,44 +1681,45 @@ async def get_location(loc_name: str) -> str: |
1681 | 1681 | ##################### |
1682 | 1682 | ## Test methods |
1683 | 1683 | ##################### |
1684 | | -# --- _map_setting_prediction -------------------------------------------------- |
1685 | 1684 | @pytest.fixture |
1686 | 1685 | def example_dict() -> MistralPredictionTypedDict: |
1687 | 1686 | """Fixture providing a typed dict for prediction.""" |
1688 | | - return {"type": "content", "content": "foo"} |
| 1687 | + return {'type': 'content', 'content': 'foo'} |
1689 | 1688 |
|
1690 | 1689 |
|
1691 | 1690 | @pytest.fixture |
1692 | 1691 | def example_prediction() -> MistralPrediction: |
1693 | 1692 | """Fixture providing a MistralPrediction object.""" |
1694 | | - return MistralPrediction(content="bar") |
| 1693 | + return MistralPrediction(content='bar') |
1695 | 1694 |
|
1696 | 1695 |
|
1697 | 1696 | @pytest.mark.parametrize( |
1698 | | - "input_value,expected_content", |
| 1697 | + 'input_value,expected_content', |
1699 | 1698 | [ |
1700 | | - ("plain text", "plain text"), |
1701 | | - ("example_prediction", "bar"), |
1702 | | - ("example_dict", "foo"), |
| 1699 | + ('plain text', 'plain text'), |
| 1700 | + ('example_prediction', 'bar'), |
| 1701 | + ('example_dict', 'foo'), |
1703 | 1702 | (None, None), |
1704 | 1703 | ], |
1705 | 1704 | ) |
1706 | | -def test_map_setting_prediction_valid(request, input_value, expected_content): |
| 1705 | +def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_value: str, expected_content: str | None): |
1707 | 1706 | """ |
1708 | 1707 | Accepted input types (str, dict, MistralPrediction, None) must be correctly converted to a MistralPrediction or None. |
1709 | 1708 | """ |
1710 | 1709 | # If the parameter is a fixture name, resolve it using request |
1711 | | - if isinstance(input_value, str) and input_value in {"example_dict", "example_prediction"}: |
1712 | | - input_value = request.getfixturevalue(input_value) |
| 1710 | + resolved_value: str | MistralPredictionTypedDict | MistralPrediction | None = input_value |
| 1711 | + if isinstance(input_value, str) and input_value in {'example_dict', 'example_prediction'}: |
| 1712 | + resolved_value = request.getfixturevalue(input_value) |
1713 | 1713 |
|
1714 | | - result = MistralModel._map_setting_prediction(input_value) # pyright: ignore[reportPrivateUsage] |
| 1714 | + result = MistralModel._map_setting_prediction(resolved_value) # pyright: ignore[reportPrivateUsage] |
1715 | 1715 |
|
1716 | | - if input_value is None: |
| 1716 | + if resolved_value is None: |
1717 | 1717 | assert result is None |
1718 | 1718 | else: |
1719 | 1719 | assert isinstance(result, MistralPrediction) |
1720 | 1720 | assert result.content == expected_content |
1721 | | -# ----------------------------------------------------- |
| 1721 | + |
| 1722 | + |
1722 | 1723 | def test_generate_user_output_format_complex(mistral_api_key: str): |
1723 | 1724 | """ |
1724 | 1725 | Single test that includes properties exercising every branch |
@@ -2304,21 +2305,19 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist |
2304 | 2305 | ] |
2305 | 2306 | ) |
2306 | 2307 |
|
| 2308 | + |
2307 | 2309 | @pytest.mark.vcr() |
2308 | 2310 | async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str): |
2309 | 2311 | """Test chat completion with prediction parameter using a math query.""" |
2310 | 2312 | from pydantic_ai.models.mistral import MistralModelSettings |
2311 | 2313 |
|
2312 | | - model = MistralModel( |
2313 | | - 'mistral-small-latest', |
2314 | | - provider=MistralProvider(api_key=mistral_api_key) |
2315 | | - ) |
2316 | | - prediction = "The result of 21+21=99" |
2317 | | - settings = MistralModelSettings(prediction=prediction) |
| 2314 | + model = MistralModel('mistral-large-2411', provider=MistralProvider(api_key=mistral_api_key)) |
| 2315 | + prediction = 'The result of 21+21=99' |
| 2316 | + settings = MistralModelSettings(mistral_prediction=prediction) |
2318 | 2317 | agent = Agent(model, model_settings=settings) |
2319 | 2318 |
|
2320 | | - result = await agent.run(['Correct only the math, respond with no explanation, no formatting.',"The result of 21+21=99"]) |
| 2319 | + result = await agent.run(['Correct the math, keep everything else. No explanation, no formatting.', prediction]) |
2321 | 2320 |
|
2322 | 2321 | # Verify that the response uses the expected prediction |
2323 | | - assert 'The result of 21+21=' in result.output |
2324 | 2322 | assert '42' in result.output |
| 2323 | + assert 'The result of 21+21=' in result.output |
0 commit comments