
Commit 72f4de5

Mistral
1 parent b8910df commit 72f4de5

2 files changed: 82 additions, 20 deletions

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 22 additions & 15 deletions

@@ -16,20 +16,15 @@
 
 import os
 from typing import Any, Optional, Union
+from pydantic import ValidationError
 
-from ..exceptions import LLMGenerationError
-from .base import LLMInterface
-from .types import LLMResponse
+from neo4j_graphrag.exceptions import LLMGenerationError
+from neo4j_graphrag.llm.base import LLMInterface
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, SystemMessage, UserMessage
 
 try:
-    from mistralai import Mistral
-    from mistralai.models.assistantmessage import AssistantMessage
+    from mistralai import Mistral, Messages
     from mistralai.models.sdkerror import SDKError
-    from mistralai.models.systemmessage import SystemMessage
-    from mistralai.models.toolmessage import ToolMessage
-    from mistralai.models.usermessage import UserMessage
-
-    MessageType = Union[AssistantMessage, SystemMessage, ToolMessage, UserMessage]
 except ImportError:
     Mistral = None  # type: ignore
     SDKError = None  # type: ignore

@@ -63,9 +58,19 @@ def __init__(
         if api_key is None:
             api_key = os.getenv("MISTRAL_API_KEY", "")
         self.client = Mistral(api_key=api_key, **kwargs)
-
-    def get_messages(self, input: str) -> list[MessageType]:
-        return [UserMessage(content=input)]
+
+    def get_messages(self, input: str, chat_history: list) -> list[Messages]:
+        messages = []
+        if self.system_instruction:
+            messages.append(SystemMessage(content=self.system_instruction).model_dump())
+        if chat_history:
+            try:
+                MessageList(messages=chat_history)
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(chat_history)
+        messages.append(UserMessage(content=input).model_dump())
+        return messages
 
     def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model

@@ -82,9 +87,10 @@ def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None
             LLMGenerationError: If anything goes wrong.
         """
         try:
+            messages = self.get_messages(input, chat_history)
             response = self.client.chat.complete(
                 model=self.model_name,
-                messages=self.get_messages(input),
+                messages=messages,
                 **self.model_params,
             )
             if response is None or response.choices is None or not response.choices:

@@ -112,9 +118,10 @@ async def ainvoke(
             LLMGenerationError: If anything goes wrong.
         """
         try:
+            messages = self.get_messages(input, chat_history)
             response = await self.client.chat.complete_async(
                 model=self.model_name,
-                messages=self.get_messages(input),
+                messages=messages,
                 **self.model_params,
            )
             if response is None or response.choices is None or not response.choices:
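
For context, a minimal usage sketch of the new chat-history flow, built from the fixtures used in the tests below; it assumes the mistralai package is installed and MISTRAL_API_KEY is set in the environment, and "mistral-model" is the illustrative model name from the tests, not a real Mistral model ID:

from neo4j_graphrag.llm.mistralai_llm import MistralAILLM

llm = MistralAILLM(
    model_name="mistral-model",  # illustrative; substitute a real Mistral chat model
    system_instruction="You are a helpful assistant.",
)

chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]

# get_messages() prepends the system instruction (if any), validates and
# appends the history, then appends the new user turn before calling
# client.chat.complete().
response = llm.invoke("What about next season?", chat_history)
print(response.content)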
Lines changed: 60 additions & 5 deletions

@@ -22,13 +22,13 @@
 
 
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral", None)
-def test_mistral_ai_llm_missing_dependency() -> None:
+def test_mistralai_llm_missing_dependency() -> None:
     with pytest.raises(ImportError):
         MistralAILLM(model_name="mistral-model")
 
 
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
-def test_mistral_ai_llm_invoke(mock_mistral: Mock) -> None:
+def test_mistralai_llm_invoke(mock_mistral: Mock) -> None:
     mock_mistral_instance = mock_mistral.return_value
 
     chat_response_mock = MagicMock()

@@ -46,9 +46,64 @@ def test_mistral_ai_llm_invoke(mock_mistral: Mock) -> None:
     assert res.content == "mistral response"
 
 
+@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
+def test_mistralai_llm_invoke_with_chat_history(mock_mistral: Mock) -> None:
+    mock_mistral_instance = mock_mistral.return_value
+    chat_response_mock = MagicMock()
+    chat_response_mock.choices = [
+        MagicMock(message=MagicMock(content="mistral response"))
+    ]
+    mock_mistral_instance.chat.complete.return_value = chat_response_mock
+    model = "mistral-model"
+    system_instruction = "You are a helpful assistant."
+
+    llm = MistralAILLM(model_name=model, system_instruction=system_instruction)
+
+    chat_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+    res = llm.invoke(question, chat_history)
+
+    assert isinstance(res, LLMResponse)
+    assert res.content == "mistral response"
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.extend(chat_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.complete.assert_called_once_with(
+        messages=messages,
+        model=model,
+    )
+
+
+@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
+def test_mistralai_llm_invoke_with_chat_history_validation_error(mock_mistral: Mock) -> None:
+    mock_mistral_instance = mock_mistral.return_value
+    chat_response_mock = MagicMock()
+    chat_response_mock.choices = [
+        MagicMock(message=MagicMock(content="mistral response"))
+    ]
+    mock_mistral_instance.chat.complete.return_value = chat_response_mock
+    model = "mistral-model"
+    system_instruction = "You are a helpful assistant."
+
+    llm = MistralAILLM(model_name=model, system_instruction=system_instruction)
+
+    chat_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "monkey", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    with pytest.raises(LLMGenerationError) as exc_info:
+        llm.invoke(question, chat_history)
+    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
+
+
 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
-async def test_mistral_ai_llm_ainvoke(mock_mistral: Mock) -> None:
+async def test_mistralai_llm_ainvoke(mock_mistral: Mock) -> None:
     mock_mistral_instance = mock_mistral.return_value
 
     async def mock_complete_async(*args: Any, **kwargs: Any) -> MagicMock:

@@ -69,7 +124,7 @@ async def mock_complete_async(*args: Any, **kwargs: Any) -> MagicMock:
 
 
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
-def test_mistral_ai_llm_invoke_sdkerror(mock_mistral: Mock) -> None:
+def test_mistralai_llm_invoke_sdkerror(mock_mistral: Mock) -> None:
     mock_mistral_instance = mock_mistral.return_value
     mock_mistral_instance.chat.complete.side_effect = SDKError("Some error")
 

@@ -81,7 +136,7 @@ def test_mistral_ai_llm_invoke_sdkerror(mock_mistral: Mock) -> None:
 
 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
-async def test_mistral_ai_llm_ainvoke_sdkerror(mock_mistral: Mock) -> None:
+async def test_mistralai_llm_ainvoke_sdkerror(mock_mistral: Mock) -> None:
     mock_mistral_instance = mock_mistral.return_value
 
     async def mock_complete_async(*args: Any, **kwargs: Any) -> None:
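
The validation-error test hinges on MessageList (from neo4j_graphrag.llm.types) rejecting any history role other than "user" or "assistant". A standalone sketch of that pydantic pattern; the HistoryMessage and History model names here are hypothetical stand-ins, not the library's actual definitions:

from typing import Literal

from pydantic import BaseModel, ValidationError

class HistoryMessage(BaseModel):  # hypothetical stand-in for the library's message type
    role: Literal["user", "assistant"]
    content: str

class History(BaseModel):  # hypothetical stand-in for MessageList
    messages: list[HistoryMessage]

# Pydantic v2 reports a bad Literal value as "Input should be 'user' or
# 'assistant'", which is the substring the new test asserts on.
try:
    History(messages=[{"role": "monkey", "content": "Usually around 6am."}])
except ValidationError as e:
    print(e.errors()[0]["msg"])  # Input should be 'user' or 'assistant'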
