Skip to content

Commit ca9d4e4

Browse files
authored
mistralai: support method="json_schema" in structured output (#29461)
https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/
1 parent e120378 commit ca9d4e4

File tree

2 files changed

+137
-12
lines changed

2 files changed

+137
-12
lines changed

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,9 @@ def with_structured_output(
686686
self,
687687
schema: Optional[Union[Dict, Type]] = None,
688688
*,
689-
method: Literal["function_calling", "json_mode"] = "function_calling",
689+
method: Literal[
690+
"function_calling", "json_mode", "json_schema"
691+
] = "function_calling",
690692
include_raw: bool = False,
691693
**kwargs: Any,
692694
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -710,13 +712,25 @@ def with_structured_output(
710712
711713
Added support for TypedDict class.
712714
713-
method:
714-
The method for steering model generation, either "function_calling"
715-
or "json_mode". If "function_calling" then the schema will be converted
716-
to an OpenAI function and the returned model will make use of the
717-
function-calling API. If "json_mode" then OpenAI's JSON mode will be
718-
used. Note that if using "json_mode" then you must include instructions
719-
for formatting the output into the desired schema into the model call.
715+
method: The method for steering model generation, one of:
716+
717+
- "function_calling":
718+
Uses Mistral's
719+
`function-calling feature <https://docs.mistral.ai/capabilities/function_calling/>`_.
720+
- "json_schema":
721+
Uses Mistral's
722+
`structured output feature <https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/>`_.
723+
- "json_mode":
724+
Uses Mistral's
725+
`JSON mode <https://docs.mistral.ai/capabilities/structured-output/json_mode/>`_.
726+
Note that if using JSON mode then you
727+
must include instructions for formatting the output into the
728+
desired schema into the model call.
729+
730+
.. versionchanged:: 0.2.5
731+
732+
Added method="json_schema"
733+
720734
include_raw:
721735
If False then only the parsed structured output is returned. If
722736
an error occurs during model output parsing it will be raised. If True
@@ -877,11 +891,11 @@ class AnswerWithJustification(BaseModel):
877891
878892
structured_llm.invoke(
879893
"Answer the following question. "
880-
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
894+
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
881895
"What's heavier a pound of bricks or a pound of feathers?"
882896
)
883897
# -> {
884-
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
898+
# 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
885899
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
886900
# 'parsing_error': None
887901
# }
@@ -893,17 +907,18 @@ class AnswerWithJustification(BaseModel):
893907
894908
structured_llm.invoke(
895909
"Answer the following question. "
896-
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
910+
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
897911
"What's heavier a pound of bricks or a pound of feathers?"
898912
)
899913
# -> {
900-
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
914+
# 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
901915
# 'parsed': {
902916
# 'answer': 'They are both the same weight.',
903917
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
904918
# },
905919
# 'parsing_error': None
906920
# }
921+
907922
""" # noqa: E501
908923
if kwargs:
909924
raise ValueError(f"Received unsupported arguments {kwargs}")
@@ -934,6 +949,20 @@ class AnswerWithJustification(BaseModel):
934949
if is_pydantic_schema
935950
else JsonOutputParser()
936951
)
952+
elif method == "json_schema":
953+
if schema is None:
954+
raise ValueError(
955+
"schema must be specified when method is 'json_schema'. "
956+
"Received None."
957+
)
958+
response_format = _convert_to_openai_response_format(schema, strict=True)
959+
llm = self.bind(response_format=response_format)
960+
961+
output_parser = (
962+
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
963+
if is_pydantic_schema
964+
else JsonOutputParser()
965+
)
937966
if include_raw:
938967
parser_assign = RunnablePassthrough.assign(
939968
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
@@ -969,3 +998,38 @@ def is_lc_serializable(cls) -> bool:
969998
def get_lc_namespace(cls) -> List[str]:
970999
"""Get the namespace of the langchain object."""
9711000
return ["langchain", "chat_models", "mistralai"]
1001+
1002+
1003+
def _convert_to_openai_response_format(
1004+
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
1005+
) -> Dict:
1006+
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
1007+
if (
1008+
isinstance(schema, dict)
1009+
and "json_schema" in schema
1010+
and schema.get("type") == "json_schema"
1011+
):
1012+
response_format = schema
1013+
elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
1014+
response_format = {"type": "json_schema", "json_schema": schema}
1015+
else:
1016+
if strict is None:
1017+
if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
1018+
strict = schema["strict"]
1019+
else:
1020+
strict = False
1021+
function = convert_to_openai_tool(schema, strict=strict)["function"]
1022+
function["schema"] = function.pop("parameters")
1023+
response_format = {"type": "json_schema", "json_schema": function}
1024+
1025+
if strict is not None and strict is not response_format["json_schema"].get(
1026+
"strict"
1027+
):
1028+
msg = (
1029+
f"Output schema already has 'strict' value set to "
1030+
f"{schema['json_schema']['strict']} but 'strict' also passed in to "
1031+
f"with_structured_output as {strict}. Please make sure that "
1032+
f"'strict' is only specified in one place."
1033+
)
1034+
raise ValueError(msg)
1035+
return response_format

libs/partners/mistralai/tests/integration_tests/test_chat_models.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import json
44
from typing import Any, Optional
55

6+
import pytest
67
from langchain_core.messages import (
78
AIMessage,
89
AIMessageChunk,
910
BaseMessageChunk,
1011
HumanMessage,
1112
)
1213
from pydantic import BaseModel
14+
from typing_extensions import TypedDict
1315

1416
from langchain_mistralai.chat_models import ChatMistralAI
1517

@@ -176,6 +178,65 @@ class Person(BaseModel):
176178
chunk_num += 1
177179

178180

181+
class Book(BaseModel):
    """Pydantic schema for a book record extracted from free text."""

    name: str
    authors: list[str]
185+
186+
class BookDict(TypedDict):
    """TypedDict variant of the book schema, mirroring ``Book``."""

    name: str
    authors: list[str]
190+
191+
def _check_parsed_result(result: Any, schema: Any) -> None:
    """Assert the parsed output matches what the given schema promises.

    A raw JSON-schema dict input gets no structural check, matching the
    original helper's behavior.
    """
    if schema is Book:
        assert isinstance(result, Book)
    elif schema is BookDict:
        # Parsed dicts may be partial mid-stream; only require that no
        # unexpected keys appear.
        assert set(result) <= {"name", "authors"}
197+
198+
@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
def test_structured_output_json_schema(schema: Any) -> None:
    """Structured output via method="json_schema" parses on invoke and stream."""
    chat = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
    structured = chat.with_structured_output(schema, method="json_schema")

    prompt = [
        {"role": "system", "content": "Extract the book's information."},
        {
            "role": "user",
            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
        },
    ]

    # Synchronous invoke path.
    _check_parsed_result(structured.invoke(prompt), schema)

    # Synchronous streaming path: every aggregated chunk must parse.
    for piece in structured.stream(prompt):
        _check_parsed_result(piece, schema)
217+
218+
219+
@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
async def test_structured_output_json_schema_async(schema: Any) -> None:
    """Async counterpart: ainvoke and astream both yield parseable output."""
    chat = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
    structured = chat.with_structured_output(schema, method="json_schema")

    prompt = [
        {"role": "system", "content": "Extract the book's information."},
        {
            "role": "user",
            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
        },
    ]

    # Async invoke path.
    _check_parsed_result(await structured.ainvoke(prompt), schema)

    # Async streaming path: every aggregated chunk must parse.
    async for piece in structured.astream(prompt):
        _check_parsed_result(piece, schema)
238+
239+
179240
def test_tool_call() -> None:
180241
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
181242

0 commit comments

Comments (0)