Commit d24d183

Mistral optimised (#396)
1 parent 6ff3619 commit d24d183

File tree

4 files changed: +100 -42 lines

- pydantic_ai_slim/pydantic_ai/models/mistral.py
- pydantic_ai_slim/pyproject.toml
- tests/models/test_mistral.py
- uv.lock

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 10 additions & 15 deletions
@@ -8,6 +8,7 @@
 from itertools import chain
 from typing import Any, Callable, Literal, Union
 
+import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
@@ -39,7 +40,6 @@
 )
 
 try:
-    from json_repair import repair_json
     from mistralai import (
         UNSET,
         CompletionChunk as MistralCompletionChunk,
@@ -198,11 +198,10 @@ async def _stream_completions_create(
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-
         model_settings = model_settings or {}
 
         if self.result_tools and self.function_tools or self.function_tools:
-            # Function Calling Mode
+            # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -218,9 +217,9 @@ async def _stream_completions_create(
         elif self.result_tools:
             # Json Mode
             parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
-
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
+
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -270,12 +269,13 @@ def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
     @staticmethod
     def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        assert response.choices, 'Unexpected empty response choice.'
+
         if response.created:
             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         else:
             timestamp = _now_utc()
 
-        assert response.choices, 'Unexpected empty response choice.'
         choice = response.choices[0]
         content = choice.message.content
         tool_calls = choice.message.tool_calls
@@ -546,20 +546,15 @@ def get(self, *, final: bool = False) -> ModelResponse:
                     calls.append(tool)
 
         elif self._delta_content and self._result_tools:
-            # NOTE: Params set for the most efficient and fastest way.
-            output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
-            assert isinstance(
-                output_json, dict
-            ), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
+            output_json: dict[str, Any] | None = pydantic_core.from_json(
+                self._delta_content, allow_partial='trailing-strings'
+            )
 
             if output_json:
                 for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `result.py`
+                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
                     # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    # For example, `return_type=list[str]` expects a 'response' key with value type array of str.
-                    # when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
-                    # when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
-                    # This ensures it's corrected to `{"response": {}}` and other required parameters and type.
+                    # Example with BaseModel and required fields.
                     if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                         continue

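The core change above swaps `json_repair.repair_json` for `pydantic_core.from_json` with `allow_partial='trailing-strings'`. `pydantic_core` already ships with `pydantic`, and its partial mode stops cleanly at the first incomplete construct rather than guessing repairs; the `'trailing-strings'` value additionally keeps a string that is still mid-stream. A minimal standalone sketch of that behaviour, with illustrative inputs that are not taken from the commit:

import pydantic_core

# Successive prefixes of a single streamed JSON object, as deltas accumulate.
prefixes = [
    '{"first": "On',
    '{"first": "One", "second":',
    '{"first": "One", "second": "Tw',
]

for partial in prefixes:
    # allow_partial='trailing-strings' keeps the incomplete trailing string,
    # while a key that has no value yet is simply omitted from the result.
    print(pydantic_core.from_json(partial, allow_partial='trailing-strings'))

# Expected output, assuming pydantic_core's documented partial-JSON semantics:
# {'first': 'On'}
# {'first': 'One'}
# {'first': 'One', 'second': 'Tw'}

This is what the updated test snapshots below encode: a key now only appears once its value has started arriving, instead of being back-filled with a default `''` the way `repair_json` did.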
pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ openai = ["openai>=1.54.3"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
 anthropic = ["anthropic>=0.40.0"]
 groq = ["groq>=0.12.0"]
-mistral = ["mistralai>=1.2.5", "json-repair>=0.30.3"]
+mistral = ["mistralai>=1.2.5"]
 logfire = ["logfire>=2.3"]
 
 [dependency-groups]
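Since `pydantic_core` is already pulled in by `pydantic` itself, the `mistral` extra no longer needs the third-party `json-repair` package; installing e.g. `pip install 'pydantic-ai-slim[mistral]'` now adds only `mistralai` on top of the core requirements (assuming the standard extras layout shown above).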

tests/models/test_mistral.py

Lines changed: 89 additions & 15 deletions
@@ -544,7 +544,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_requests: None):
 #####################
 
 
-async def test_stream_structured_with_all_typd(allow_model_requests: None):
+async def test_stream_structured_with_all_type(allow_model_requests: None):
     class MyTypedDict(TypedDict, total=False):
         first: str
         second: int
@@ -563,19 +563,19 @@ class MyTypedDict(TypedDict, total=False):
             '", "second": 2',
         ),
         text_chunk(
-            '", "bool_value": true',
+            ', "bool_value": true',
         ),
         text_chunk(
-            '", "nullable_value": null',
+            ', "nullable_value": null',
         ),
         text_chunk(
-            '", "array_value": ["A", "B", "C"]',
+            ', "array_value": ["A", "B", "C"]',
        ),
         text_chunk(
-            '", "dict_value": {"A": "A", "B":"B"}',
+            ', "dict_value": {"A": "A", "B":"B"}',
         ),
         text_chunk(
-            '", "dict_int_value": {"A": 1, "B":2}',
+            ', "dict_int_value": {"A": 1, "B":2}',
         ),
         text_chunk('}'),
         chunk([]),
@@ -721,8 +721,8 @@ class MyTypedDict(TypedDict, total=False):
             {'first': 'One'},
             {'first': 'One'},
             {'first': 'One'},
-            {'first': 'One', 'second': ''},
-            {'first': 'One', 'second': ''},
+            {'first': 'One'},
+            {'first': 'One'},
             {'first': 'One', 'second': ''},
             {'first': 'One', 'second': 'T'},
             {'first': 'One', 'second': 'Tw'},
@@ -828,20 +828,21 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
     v = [c async for c in result.stream(debounce_by=None)]
     assert v == snapshot(
         [
+            [''],
             ['f'],
             ['fi'],
             ['fir'],
             ['firs'],
             ['first'],
             ['first'],
             ['first'],
-            ['first'],
+            ['first', ''],
             ['first', 'O'],
             ['first', 'On'],
             ['first', 'One'],
             ['first', 'One'],
             ['first', 'One'],
-            ['first', 'One'],
+            ['first', 'One', ''],
             ['first', 'One', 's'],
             ['first', 'One', 'se'],
             ['first', 'One', 'sec'],
@@ -850,7 +851,7 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
             ['first', 'One', 'second'],
             ['first', 'One', 'second'],
             ['first', 'One', 'second'],
-            ['first', 'One', 'second'],
+            ['first', 'One', 'second', ''],
             ['first', 'One', 'second', 'T'],
             ['first', 'One', 'second', 'Tw'],
             ['first', 'One', 'second', 'Two'],
@@ -869,10 +870,10 @@
     assert result.usage().response_tokens == len(stream)
 
 
-async def test_stream_result_type_basemodel(allow_model_requests: None):
+async def test_stream_result_type_basemodel_with_default_params(allow_model_requests: None):
     class MyTypedBaseModel(BaseModel):
-        first: str = ''  # Note: Don't forget to set default values
-        second: str = ''
+        first: str = ''  # Note: Default, set value.
+        second: str = ''  # Note: Default, set value.
 
     # Given
     stream = [
@@ -958,6 +959,79 @@ class MyTypedBaseModel(BaseModel):
     assert result.usage().response_tokens == len(stream)
 
 
+async def test_stream_result_type_basemodel_with_required_params(allow_model_requests: None):
+    class MyTypedBaseModel(BaseModel):
+        first: str  # Note: Required params
+        second: str  # Note: Required params
+
+    # Given
+    stream = [
+        text_chunk('{'),
+        text_chunk('"'),
+        text_chunk('f'),
+        text_chunk('i'),
+        text_chunk('r'),
+        text_chunk('s'),
+        text_chunk('t'),
+        text_chunk('"'),
+        text_chunk(':'),
+        text_chunk(' '),
+        text_chunk('"'),
+        text_chunk('O'),
+        text_chunk('n'),
+        text_chunk('e'),
+        text_chunk('"'),
+        text_chunk(','),
+        text_chunk(' '),
+        text_chunk('"'),
+        text_chunk('s'),
+        text_chunk('e'),
+        text_chunk('c'),
+        text_chunk('o'),
+        text_chunk('n'),
+        text_chunk('d'),
+        text_chunk('"'),
+        text_chunk(':'),
+        text_chunk(' '),
+        text_chunk('"'),
+        text_chunk('T'),
+        text_chunk('w'),
+        text_chunk('o'),
+        text_chunk('"'),
+        text_chunk('}'),
+        chunk([]),
+    ]
+
+    mock_client = MockMistralAI.create_stream_mock(stream)
+    model = MistralModel('mistral-large-latest', client=mock_client)
+    agent = Agent(model=model, result_type=MyTypedBaseModel)
+
+    # When
+    async with agent.run_stream('User prompt value') as result:
+        # Then
+        assert result.is_structured
+        assert not result.is_complete
+        v = [c async for c in result.stream(debounce_by=None)]
+        assert v == snapshot(
+            [
+                MyTypedBaseModel(first='One', second=''),
+                MyTypedBaseModel(first='One', second='T'),
+                MyTypedBaseModel(first='One', second='Tw'),
+                MyTypedBaseModel(first='One', second='Two'),
+                MyTypedBaseModel(first='One', second='Two'),
+                MyTypedBaseModel(first='One', second='Two'),
+                MyTypedBaseModel(first='One', second='Two'),
+            ]
+        )
+        assert result.is_complete
+        assert result.usage().request_tokens == 34
+        assert result.usage().response_tokens == 34
+        assert result.usage().total_tokens == 34
+
+        # double check cost matches stream count
+        assert result.usage().response_tokens == len(stream)
+
+
 #####################
 ## Completion Function call
 #####################
@@ -1693,6 +1767,6 @@ def test_generate_user_output_format_multiple():
         ),
     ],
 )
-def test_validate_required_json_shema(desc: str, schema: dict[str, Any], data: dict[str, Any], expected: bool) -> None:
+def test_validate_required_json_schema(desc: str, schema: dict[str, Any], data: dict[str, Any], expected: bool) -> None:
     result = MistralStreamStructuredResponse._validate_required_json_schema(data, schema)  # pyright: ignore[reportPrivateUsage]
     assert result == expected, f'{desc} — expected {expected}, got {result}'
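
The new `test_stream_result_type_basemodel_with_required_params` case relies on the `continue` branch shown in `mistral.py` above: partial objects that do not yet satisfy the result tool's JSON schema are skipped, so the first value the stream yields is already a valid `MyTypedBaseModel(first='One', second='')`. A simplified sketch of that kind of gate, checking only top-level required properties (a hypothetical helper, not the commit's actual `_validate_required_json_schema`):

from typing import Any

def has_required_fields(partial: dict[str, Any], schema: dict[str, Any]) -> bool:
    """Return True once every required top-level property is present with a
    type matching the JSON schema (simplified: nesting is not handled)."""
    type_map = {'string': str, 'integer': int, 'boolean': bool, 'array': list, 'object': dict}
    for name in schema.get('required', []):
        if name not in partial:
            return False  # required key has not been streamed yet
        expected = schema.get('properties', {}).get(name, {}).get('type')
        if expected in type_map and not isinstance(partial[name], type_map[expected]):
            return False  # value present but with the wrong type so far
    return True

schema = {
    'required': ['first', 'second'],
    'properties': {'first': {'type': 'string'}, 'second': {'type': 'string'}},
}
assert not has_required_fields({'first': 'One'}, schema)            # 'second' missing
assert has_required_fields({'first': 'One', 'second': ''}, schema)  # both keys present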

uv.lock

Lines changed: 0 additions & 11 deletions
Generated lockfile; diff not rendered by default.
