
Commit 5720a4b

VertexAI
1 parent 72f4de5 commit 5720a4b

File tree

2 files changed: +88 −6 lines


src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 31 additions & 4 deletions
@@ -15,12 +15,19 @@
 
 from typing import Any, Optional
 
+from pydantic import ValidationError
+
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse
+from neo4j_graphrag.llm.types import LLMResponse, MessageList
 
 try:
-    from vertexai.generative_models import GenerativeModel, ResponseValidationError
+    from vertexai.generative_models import (
+        GenerativeModel,
+        ResponseValidationError,
+        Part,
+        Content,
+    )
 except ImportError:
     GenerativeModel = None
     ResponseValidationError = None
@@ -69,6 +76,24 @@ def __init__(
             model_name=model_name, system_instruction=[system_instruction], **kwargs
         )
 
+
+    def get_messages(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> list[Content]:
+        messages = []
+        if chat_history:
+            try:
+                MessageList(messages=chat_history)
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+
+            for message in chat_history:
+                if message.get("role") == "user":
+                    messages.append(Content(role="user", parts=[Part.from_text(message.get("content"))]))
+                elif message.get("role") == "assistant":
+                    messages.append(Content(role="model", parts=[Part.from_text(message.get("content"))]))
+
+        messages.append(Content(role="user", parts=[Part.from_text(input)]))
+        return messages
+
     def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -80,7 +105,8 @@ def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None
             LLMResponse: The response from the LLM.
         """
         try:
-            response = self.model.generate_content(input, **self.model_params)
+            messages = self.get_messages(input, chat_history)
+            response = self.model.generate_content(messages, **self.model_params)
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)
@@ -98,8 +124,9 @@ async def ainvoke(
             LLMResponse: The response from the LLM.
         """
         try:
+            messages = self.get_messages(input, chat_history)
             response = await self.model.generate_content_async(
-                input, **self.model_params
+                messages, **self.model_params
             )
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
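
Taken together, the change lets callers pass prior conversation turns through invoke and ainvoke. A minimal usage sketch (not part of the commit), assuming Google Cloud credentials are configured and the Vertex AI dependency is installed:

    from neo4j_graphrag.llm.vertexai_llm import VertexAILLM

    llm = VertexAILLM(
        model_name="gemini-1.5-flash-001",
        system_instruction="You are a helpful assistant.",
    )

    # Prior turns use the generic {"role", "content"} shape; get_messages
    # validates them, maps "assistant" onto the Vertex AI "model" role, and
    # appends the new input as the final user turn before generation.
    chat_history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]
    response = llm.invoke("When does it set?", chat_history)
    print(response.content)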

tests/unit/llm/test_vertexai_llm.py

Lines changed: 57 additions & 2 deletions
@@ -13,10 +13,13 @@
 # limitations under the License.
 from __future__ import annotations
 
+from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
+from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
+from vertexai.generative_models import Content, Part
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -36,7 +39,59 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     input_text = "may thy knife chip and shatter"
     response = llm.invoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content.assert_called_once_with(input_text, **model_params)
+    llm.model.generate_content.assert_called_once_with([mock.ANY], **model_params)
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    question = "When does it set?"
+    chat_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+        {"role": "user", "content": "What about next season?"},
+        {"role": "assistant", "content": "Around 8am."},
+    ]
+    expected_response = [
+        Content(
+            role="user",
+            parts=[Part.from_text("When does the sun come up in the summer?")],
+        ),
+        Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
+        Content(role="user", parts=[Part.from_text("What about next season?")]),
+        Content(role="model", parts=[Part.from_text("Around 8am.")]),
+        Content(role="user", parts=[Part.from_text("When does it set?")]),
+    ]
+
+    llm = VertexAILLM(
+        model_name=model_name, system_instruction=system_instruction
+    )
+    response = llm.get_messages(question, chat_history)
+
+    GenerativeModelMock.assert_called_once_with(model_name=model_name, system_instruction=[system_instruction])
+    assert len(response) == len(expected_response)
+    for actual, expected in zip(response, expected_response):
+        assert actual.role == expected.role
+        assert actual.parts[0].text == expected.parts[0].text
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    question = "hi!"
+    chat_history = [
+        {"role": "model", "content": "hello!"},
+    ]
+
+    llm = VertexAILLM(
+        model_name=model_name, system_instruction=system_instruction
+    )
+    with pytest.raises(LLMGenerationError) as exc_info:
+        llm.invoke(question, chat_history)
+    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
 
 
 @pytest.mark.asyncio
@@ -51,4 +106,4 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)
+    llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params)
