
Commit a362fd3

Ollama
* also adds the `options` parameter to the ollama `chat` call
1 parent 6288907 commit a362fd3

File tree

2 files changed: +141 −34 lines

src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 38 additions & 20 deletions
@@ -12,29 +12,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Iterable, Optional
+
+from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 
 from .base import LLMInterface
-from .types import LLMResponse
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+
+try:
+    import ollama
+    from ollama import Message
+except ImportError:
+    ollama = None
 
 
 class OllamaLLM(LLMInterface):
     def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
-        try:
-            import ollama
-        except ImportError:
+        if ollama is None:
             raise ImportError(
                 "Could not import ollama Python client. "
                 "Please install it with `pip install ollama`."
             )
-        super().__init__(model_name, model_params, **kwargs)
+        super().__init__(model_name, model_params, system_instruction, **kwargs)
         self.ollama = ollama
         self.client = ollama.Client(
             **kwargs,
@@ -43,32 +50,43 @@ def __init__(
             **kwargs,
         )
 
-    def invoke(self, input: str) -> LLMResponse:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> Iterable[Message]:
+        messages = []
+        if self.system_instruction:
+            messages.append(SystemMessage(content=self.system_instruction).model_dump())
+        if chat_history:
+            try:
+                MessageList(messages=chat_history)
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(chat_history)
+        messages.append(UserMessage(content=input).model_dump())
+        return messages
+
+    def invoke(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> LLMResponse:
         try:
             response = self.client.chat(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    },
-                ],
+                messages=self.get_messages(input, chat_history),
+                options=self.model_params,
             )
             content = response.message.content or ""
             return LLMResponse(content=content)
         except self.ollama.ResponseError as e:
             raise LLMGenerationError(e)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> LLMResponse:
         try:
             response = await self.async_client.chat(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    },
-                ],
+                messages=self.get_messages(input, chat_history),
+                options=self.model_params,
             )
             content = response.message.content or ""
             return LLMResponse(content=content)
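
For orientation (an editorial sketch, not part of the commit): the changes above mean the LLM can now take a system instruction and prior chat turns, and that `model_params` are forwarded to Ollama's `chat` call as `options`. A minimal usage sketch, assuming the `ollama` package is installed, a local Ollama server is running, and an illustrative model name:

from neo4j_graphrag.llm.ollama_llm import OllamaLLM

# model_params are forwarded to ollama's chat() as `options`
llm = OllamaLLM(
    "llama3",  # illustrative model name, assumed to be available locally
    model_params={"temperature": 0.3},
    system_instruction="You are a helpful assistant.",
)

# Optional prior turns; each role must be "user", "assistant" or "system",
# otherwise get_messages() raises LLMGenerationError.
chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]

response = llm.invoke("What about next season?", chat_history)
print(response.content)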

tests/unit/llm/test_ollama_llm.py

Lines changed: 103 additions & 14 deletions
@@ -12,35 +12,124 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
 from unittest.mock import MagicMock, Mock, patch
 
 import ollama
 import pytest
+from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm import LLMResponse
 from neo4j_graphrag.llm.ollama_llm import OllamaLLM
 
 
-def get_mock_ollama() -> MagicMock:
-    mock = MagicMock()
-    mock.ResponseError = ollama.ResponseError
-    return mock
-
-
-@patch("builtins.__import__", side_effect=ImportError)
-def test_ollama_llm_missing_dependency(mock_import: Mock) -> None:
+@patch("neo4j_graphrag.llm.ollama_llm.ollama", None)
+def test_ollama_llm_missing_dependency() -> None:
     with pytest.raises(ImportError):
         OllamaLLM(model_name="gpt-4o")
 
 
-@patch("builtins.__import__")
-def test_ollama_llm_happy_path(mock_import: Mock) -> None:
-    mock_ollama = get_mock_ollama()
-    mock_import.return_value = mock_ollama
+@patch("neo4j_graphrag.llm.ollama_llm.ollama")
+def test_ollama_llm_happy_path(mock_ollama: Mock) -> None:
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama chat response"),
+    )
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    question = "What is graph RAG?"
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
+
+    res = llm.invoke(question)
+    assert isinstance(res, LLMResponse)
+    assert res.content == "ollama chat response"
+    messages = [
+        {"role": "system", "content": system_instruction},
+        {"role": "user", "content": question},
+    ]
+    llm.client.chat.assert_called_once_with(
+        model=model, messages=messages, options=model_params
+    )
+
+
+@patch("neo4j_graphrag.llm.ollama_llm.ollama")
+def test_ollama_invoke_with_chat_history_happy_path(mock_ollama: Mock) -> None:
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama chat response"),
+    )
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
+    chat_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    response = llm.invoke(question, chat_history)
+    assert response.content == "ollama chat response"
+    messages = [{"role": "system", "content": system_instruction}]
+    messages.extend(chat_history)
+    messages.append({"role": "user", "content": question})
+    llm.client.chat.assert_called_once_with(
+        model=model, messages=messages, options=model_params
+    )
+
+
+@patch("neo4j_graphrag.llm.ollama_llm.ollama")
+def test_ollama_invoke_with_chat_history_validation_error(
+    mock_ollama: Mock,
+) -> None:
     mock_ollama.Client.return_value.chat.return_value = MagicMock(
         message=MagicMock(content="ollama chat response"),
     )
-    llm = OllamaLLM(model_name="gpt")
+    mock_ollama.ResponseError = ollama.ResponseError
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
+    chat_history = [
+        {"role": "human", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+    ]
+    question = "What about next season?"
+
+    with pytest.raises(LLMGenerationError) as exc_info:
+        llm.invoke(question, chat_history)
+    assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+@patch("neo4j_graphrag.llm.ollama_llm.ollama")
+async def test_ollama_ainvoke_happy_path(mock_ollama: Mock) -> None:
+    async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock:
+        return MagicMock(
+            message=MagicMock(content="ollama chat response"),
+        )
+
+    mock_ollama.AsyncClient.return_value.chat = mock_chat_async
+    model = "gpt"
+    model_params = {"temperature": 0.3}
+    system_instruction = "You are a helpful assistant."
+    question = "What is graph RAG?"
+    llm = OllamaLLM(
+        model,
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
 
-    res = llm.invoke("my text")
+    res = await llm.ainvoke(question)
     assert isinstance(res, LLMResponse)
     assert res.content == "ollama chat response"
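
Side note (an assumption, not shown in this commit): the `SystemMessage`, `UserMessage` and `MessageList` types imported from `.types` are what make the "human" role above fail validation. A hypothetical reconstruction, inferred from the dicts the tests expect and the asserted error message; the real definitions in `neo4j_graphrag.llm.types` may differ:

from typing import Literal

from pydantic import BaseModel


# Hypothetical sketch of the message models used by OllamaLLM.get_messages().
class BaseMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class UserMessage(BaseMessage):
    role: Literal["user"] = "user"


class SystemMessage(BaseMessage):
    role: Literal["system"] = "system"


class MessageList(BaseModel):
    # Validating a list containing a role such as "human" raises a pydantic
    # ValidationError, which OllamaLLM re-raises as LLMGenerationError.
    messages: list[BaseMessage]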

0 commit comments
