Commit 597eff1

Formatting
1 parent 5720a4b commit 597eff1

15 files changed: +160 -91 lines changed

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 12 additions & 5 deletions
@@ -14,7 +14,6 @@
 # limitations under the License.
 from __future__ import annotations
 
-import json
 import logging
 import warnings
 from typing import Any, Optional
@@ -25,7 +24,11 @@
     RagInitializationError,
     SearchValidationError,
 )
-from neo4j_graphrag.generation.prompts import RagTemplate, ChatSummaryTemplate, ConversationTemplate
+from neo4j_graphrag.generation.prompts import (
+    RagTemplate,
+    ChatSummaryTemplate,
+    ConversationTemplate,
+)
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
 from neo4j_graphrag.retrievers.base import Retriever
@@ -142,10 +145,14 @@ def search(
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)
-
+
     def build_query(self, query_text: str, chat_history: list[dict[str, str]]) -> str:
         if chat_history:
-            summarization_prompt = ChatSummaryTemplate().format(chat_history=chat_history)
+            summarization_prompt = ChatSummaryTemplate().format(
+                chat_history=chat_history
+            )
             summary = self.llm.invoke(summarization_prompt).content
-            return ConversationTemplate().format(summary=summary, current_query=query_text)
+            return ConversationTemplate().format(
+                summary=summary, current_query=query_text
+            )
         return query_text
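
For context, a hedged usage sketch of the build_query flow reformatted above (the GraphRAG import path is assumed; retriever and llm are stand-ins constructed elsewhere; the sample history is illustrative):

from neo4j_graphrag.generation import GraphRAG

rag = GraphRAG(retriever=retriever, llm=llm)  # retriever/llm built elsewhere

chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]

# build_query first summarizes chat_history with ChatSummaryTemplate and the LLM,
# then folds the summary and the new question into one ConversationTemplate query.
query = rag.build_query("What about next season?", chat_history)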

src/neo4j_graphrag/generation/prompts.py

Lines changed: 6 additions & 3 deletions
@@ -192,7 +192,7 @@ def format(
         text: str = "",
     ) -> str:
         return super().format(text=text, schema=schema, examples=examples)
-
+
 
 class ChatSummaryTemplate(PromptTemplate):
     DEFAULT_TEMPLATE = """
@@ -203,8 +203,11 @@ class ChatSummaryTemplate(PromptTemplate):
     EXPECTED_INPUTS = ["chat_history"]
 
     def format(self, chat_history: list[dict[str, str]]) -> str:
-        message_list = [': '.join([f"{value}" for _, value in message.items()]) for message in chat_history]
-        history = '\n'.join(message_list)
+        message_list = [
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in chat_history
+        ]
+        history = "\n".join(message_list)
         return super().format(chat_history=history)
 
 
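As a standalone illustration of the comprehension above, the flattening in ChatSummaryTemplate.format behaves like this sketch (the sample history is invented):

chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]
# Each {"role": ..., "content": ...} dict is flattened to "role: content".
message_list = [
    ": ".join([f"{value}" for _, value in message.items()])
    for message in chat_history
]
history = "\n".join(message_list)
print(history)
# user: When does the sun come up in the summer?
# assistant: Usually around 6am.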

src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 9 additions & 7 deletions
@@ -82,7 +82,9 @@ def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]
         messages.append(UserMessage(content=input).model_dump())
         return messages
 
-    def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
@@ -95,9 +97,9 @@ def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None
         try:
             messages = self.get_messages(input, chat_history)
             response = self.client.messages.create(
-                model = self.model_name,
-                system = self.system_instruction,
-                messages = messages,
+                model=self.model_name,
+                system=self.system_instruction,
+                messages=messages,
                 **self.model_params,
             )
             return LLMResponse(content=response.content)
@@ -119,9 +121,9 @@ async def ainvoke(
         try:
             messages = self.get_messages(input, chat_history)
             response = await self.async_client.messages.create(
-                model = self.model_name,
-                system = self.system_instruction,
-                messages = messages,
+                model=self.model_name,
+                system=self.system_instruction,
+                messages=messages,
                 **self.model_params,
             )
             return LLMResponse(content=response.content)
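
For reference, a hedged sketch of calling the reformatted invoke (requires the anthropic package and credentials; model name and prompts are illustrative, mirroring the unit tests below):

from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM

llm = AnthropicLLM(
    "claude-3-opus-20240229",
    model_params={"temperature": 0.3},
    system_instruction="You are a helpful assistant.",
)
chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]
# get_messages appends the new input after the validated history; invoke then
# forwards model/system/messages plus model_params to client.messages.create.
response = llm.invoke("What about next season?", chat_history)
print(response.content)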

src/neo4j_graphrag/llm/base.py

Lines changed: 3 additions & 1 deletion
@@ -42,7 +42,9 @@ def __init__(
         self.system_instruction = system_instruction
 
     @abstractmethod
-    def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
 
         Args:
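
A minimal sketch of a concrete subclass satisfying this abstract signature (EchoLLM is a made-up example; it assumes ainvoke is the interface's other abstract method, as the provider diffs below suggest):

from typing import Optional

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse


class EchoLLM(LLMInterface):
    """Toy implementation that simply echoes the input; illustrates the contract only."""

    def invoke(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        return LLMResponse(content=input)

    async def ainvoke(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        return LLMResponse(content=input)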

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 11 additions & 4 deletions
@@ -14,12 +14,17 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable, Optional
+from typing import Any, Optional
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, SystemMessage, UserMessage
+from neo4j_graphrag.llm.types import (
+    LLMResponse,
+    MessageList,
+    SystemMessage,
+    UserMessage,
+)
 
 try:
     import cohere
@@ -69,7 +74,7 @@ def __init__(
         self.client = cohere.ClientV2(**kwargs)
         self.async_client = cohere.AsyncClientV2(**kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> ChatMessages: # type: ignore
+    def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type: ignore
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
@@ -82,7 +87,9 @@ def get_messages(self, input: str, chat_history: list) -> ChatMessages: # type:
         messages.append(UserMessage(content=input).model_dump())
         return messages
 
-    def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 11 additions & 4 deletions
@@ -15,12 +15,17 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, Union
+from typing import Any, Optional
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, SystemMessage, UserMessage
+from neo4j_graphrag.llm.types import (
+    LLMResponse,
+    MessageList,
+    SystemMessage,
+    UserMessage,
+)
 
 try:
     from mistralai import Mistral, Messages
@@ -58,7 +63,7 @@ def __init__(
         if api_key is None:
             api_key = os.getenv("MISTRAL_API_KEY", "")
         self.client = Mistral(api_key=api_key, **kwargs)
-
+
     def get_messages(self, input: str, chat_history: list) -> list[Messages]:
         messages = []
         if self.system_instruction:
@@ -72,7 +77,9 @@ def get_messages(self, input: str, chat_history: list) -> list[Messages]:
         messages.append(UserMessage(content=input).model_dump())
         return messages
 
-    def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.
 

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 4 additions & 4 deletions
@@ -61,9 +61,7 @@ def __init__(
         super().__init__(model_name, model_params, system_instruction)
 
     def get_messages(
-        self,
-        input: str,
-        chat_history: list,
+        self, input: str, chat_history: list
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
         if self.system_instruction:
@@ -77,7 +75,9 @@ def get_messages(
         messages.append(UserMessage(content=input).model_dump())
         return messages
 
-    def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
 

src/neo4j_graphrag/llm/types.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ class BaseMessage(BaseModel):
 class UserMessage(BaseMessage):
     role: Literal["user"] = "user"
 
+
 class SystemMessage(BaseMessage):
     role: Literal["system"] = "system"
 

src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 14 additions & 5 deletions
@@ -76,25 +76,34 @@ def __init__(
             model_name=model_name, system_instruction=[system_instruction], **kwargs
         )
 
-
     def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
         messages = []
         if chat_history:
             try:
                 MessageList(messages=chat_history)
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-
+
             for message in chat_history:
                 if message.get("role") == "user":
-                    messages.append(Content(role="user", parts=[Part.from_text(message.get("content"))]))
+                    messages.append(
+                        Content(
+                            role="user", parts=[Part.from_text(message.get("content"))]
+                        )
+                    )
                 elif message.get("role") == "assistant":
-                    messages.append(Content(role="model", parts=[Part.from_text(message.get("content"))]))
+                    messages.append(
+                        Content(
+                            role="model", parts=[Part.from_text(message.get("content"))]
+                        )
+                    )
 
         messages.append(Content(role="user", parts=[Part.from_text(input)]))
         return messages
 
-    def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
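
A small sketch of the role mapping get_messages performs: an "assistant" turn in chat_history becomes a VertexAI Content with role "model" (import path assumed from the google-cloud-aiplatform SDK; the sample history is illustrative):

from vertexai.generative_models import Content, Part

chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]
messages = []
for message in chat_history:
    # VertexAI expects role "model" where the chat history says "assistant".
    role = "user" if message.get("role") == "user" else "model"
    messages.append(Content(role=role, parts=[Part.from_text(message.get("content"))]))
# The new user input is appended last.
messages.append(Content(role="user", parts=[Part.from_text("What about next season?")]))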

tests/unit/llm/test_anthropic_llm.py

Lines changed: 15 additions & 8 deletions
@@ -13,11 +13,8 @@
 # limitations under the License.
 from __future__ import annotations
 
-import sys
-from typing import Generator
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
-import anthropic
 import pytest
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM
@@ -54,17 +51,21 @@ def test_anthropic_invoke_with_chat_history_happy_path(mock_anthropic: Mock) ->
     )
     model_params = {"temperature": 0.3}
     system_instruction = "You are a helpful assistant."
-    llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params, system_instruction=system_instruction)
+    llm = AnthropicLLM(
+        "claude-3-opus-20240229",
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
     chat_history = [
         {"role": "user", "content": "When does the sun come up in the summer?"},
         {"role": "assistant", "content": "Usually around 6am."},
     ]
     question = "What about next season?"
-
+
     response = llm.invoke(question, chat_history)
     assert response.content == "generated text"
     chat_history.append({"role": "user", "content": question})
-    llm.client.messages.create.assert_called_once_with(  # type: ignore
+    llm.client.messages.create.assert_called_once_with(
         messages=chat_history,
         model="claude-3-opus-20240229",
         system=system_instruction,
@@ -73,13 +74,19 @@
 
 
 @patch("neo4j_graphrag.llm.anthropic_llm.anthropic.Anthropic")
-def test_anthropic_invoke_with_chat_history_validation_error(mock_anthropic: Mock) -> None:
+def test_anthropic_invoke_with_chat_history_validation_error(
+    mock_anthropic: Mock,
+) -> None:
     mock_anthropic.return_value.messages.create.return_value = MagicMock(
         content="generated text"
     )
     model_params = {"temperature": 0.3}
     system_instruction = "You are a helpful assistant."
-    llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params, system_instruction=system_instruction)
+    llm = AnthropicLLM(
+        "claude-3-opus-20240229",
+        model_params=model_params,
+        system_instruction=system_instruction,
+    )
     chat_history = [
         {"role": "human", "content": "When does the sun come up in the summer?"},
         {"role": "assistant", "content": "Usually around 6am."},
