
Commit 62bbea3

LangChain Compability Restored (#442)
* langchain compatible llminterfacev2 interface created
* langchain compatible methods added to openai
* broken test cases sorted
* brand new tests for llm interface v2
* vertex ai started to support llm interface v2
* brand new test cases added
* invoke with tools method is not mandatory
* llminterfacev2 support added to ollama
* llm interface v2 supported in mistralai
* llm interface v2 support added to cohere
* llm interface v2 support added to anthropic
* more tests for llminterfacev2
* mypy fixes
* test_openai_llm possibly failed because of import in ci cd
* Attempt CI/CD-compatible async mock for OpenAILLM tests
* Attempt CI/CD-safe async mock for OpenAILLM v1 test
* restoring graphrag e2e tests for v1
* existing e2e test for graphrag sorted
* create message history only if langchain compatible branch
* typo in docstring
* avoid repeated langchain compatible check code
* avoid to create vector idx and fulltext with same property
* avoid to create vector idx and fulltext with same property
* make is instance of langchain check prettier
* use params from invoke if available similar to LC
* define kwargs for invoke in the interface
* fixing outer scope definition warning
* arg names replaced with LC arg names
* some additional docstring for new input args
* docstring updated
* use (a)invokev2 function names instead of brand new
* use (a)invoke_v1 function name instead of legacy invoke
* keep rate limit handlers for v2 functions
* use warning for deprecated llm interface v1
* resolve conflict after Alex's recent uv change
* formatted
* mypy problems sorted
* initialize llms with llm interface v2
* invoke with tools dropped in llm interface v2
* revert the behaviour back for return_context
* e2e tests updated for restored behaviour for return_context
* llm can also be lc object, added any for this
* users initialize their own llm objects, so we don't have to make llm interface's init same as lc
* auto format after manual conflict resolution
1 parent: 92f4b14

22 files changed: +3351 −285 lines
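In practical terms, the changes below let GraphRAG accept a LangChain chat model or an LLMInterfaceV2 implementation in place of the legacy LLMInterface, while v1 LLMInterface subclasses keep working through a backward-compatibility branch. A minimal usage sketch, not part of this commit: the connection details, index name and the langchain_openai ChatOpenAI model are illustrative assumptions.

# Sketch only: GraphRAG driven by a LangChain chat model after this change.
# Driver credentials, index name and the ChatOpenAI model are assumptions.
import neo4j
from langchain_openai import ChatOpenAI  # any LangChain BaseChatModel should work
from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.retrievers import VectorRetriever

driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
retriever = VectorRetriever(driver, index_name="my_vector_index", embedder=OpenAIEmbeddings())
llm = ChatOpenAI(model="gpt-4o")  # no neo4j_graphrag wrapper needed any more

rag = GraphRAG(retriever=retriever, llm=llm)
result = rag.search(query_text="What changed in this commit?", return_context=True)
print(result.answer)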

src/neo4j_graphrag/generation/graphrag.py

Lines changed: 83 additions & 23 deletions
@@ -12,29 +12,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
 
+# built-in dependencies
+from __future__ import annotations
 import logging
 import warnings
 from typing import Any, List, Optional, Union
 
+# 3rd party dependencies
 from pydantic import ValidationError
 
+# project dependencies
 from neo4j_graphrag.exceptions import (
     RagInitializationError,
     SearchValidationError,
 )
 from neo4j_graphrag.generation.prompts import RagTemplate
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
-from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm import LLMInterface, LLMInterfaceV2
+from neo4j_graphrag.llm.utils import legacy_inputs_to_messages
 from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import LLMMessage, RetrieverResult
 from neo4j_graphrag.utils.logging import prettify
 
+# Set up logger
 logger = logging.getLogger(__name__)
 
 
+# pylint: disable=raise-missing-from
 class GraphRAG:
     """Performs a GraphRAG search using a specific retriever
     and LLM.
@@ -57,8 +63,10 @@ class GraphRAG:
 
     Args:
         retriever (Retriever): The retriever used to find relevant context to pass to the LLM.
-        llm (LLMInterface): The LLM used to generate the answer.
-        prompt_template (RagTemplate): The prompt template that will be formatted with context and user question and passed to the LLM.
+        llm (LLMInterface, LLMInterfaceV2 or LangChain Chat Model): The LLM used to generate
+            the answer.
+        prompt_template (RagTemplate): The prompt template that will be formatted with context and
+            user question and passed to the LLM.
 
     Raises:
         RagInitializationError: If validation of the input arguments fail.
@@ -67,7 +75,7 @@
     def __init__(
         self,
         retriever: Retriever,
-        llm: LLMInterface,
+        llm: Union[LLMInterface, LLMInterfaceV2, Any],
         prompt_template: RagTemplate = RagTemplate(),
     ):
         try:
@@ -93,7 +101,8 @@ def search(
     ) -> RagResultModel:
         """
         .. warning::
-            The default value of 'return_context' will change from 'False' to 'True' in a future version.
+            The default value of 'return_context' will change from 'False'
+            to 'True' in a future version.
 
 
         This method performs a full RAG search:
@@ -104,24 +113,28 @@ def search(
 
         Args:
             query_text (str): The user question.
-            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
-                with each message having a specific role assigned.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection
+                of previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
-            return_context (bool): Whether to append the retriever result to the final result (default: False).
-            response_fallback (Optional[str]): If not null, will return this message instead of calling the LLM if context comes back empty.
+            return_context (bool): Whether to append the retriever result to the final result
+                (default: False).
+            response_fallback (Optional[str]): If not null, will return this message instead
+                of calling the LLM if context comes back empty.
 
         Returns:
             RagResultModel: The LLM-generated answer.
 
         """
         if return_context is None:
             warnings.warn(
-                "The default value of 'return_context' will change from 'False' to 'True' in a future version.",
+                "The default value of 'return_context' will change from 'False'"
+                " to 'True' in a future version.",
                 DeprecationWarning,
             )
             return_context = False
+
         try:
             validated_data = RagSearchModel(
                 query_text=query_text,
@@ -145,13 +158,31 @@ def search(
         prompt = self.prompt_template.format(
             query_text=query_text, context=context, examples=validated_data.examples
         )
-        logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
-        logger.debug(f"RAG: prompt={prompt}")
-        llm_response = self.llm.invoke(
-            prompt,
-            message_history,
-            system_instruction=self.prompt_template.system_instructions,
-        )
+
+        logger.debug("RAG: retriever_result=%s", prettify(retriever_result))
+        logger.debug("RAG: prompt=%s", prompt)
+
+        if self.is_langchain_compatible():
+            # llm interface v2 or langchain chat model
+            messages = legacy_inputs_to_messages(
+                prompt=prompt,
+                message_history=message_history,
+                system_instruction=self.prompt_template.system_instructions,
+            )
+
+            # langchain chat model compatible invoke
+            llm_response = self.llm.invoke(
+                input=messages,
+            )
+        elif isinstance(self.llm, LLMInterface):
+            # may have custom LLMs inherited from V1, keep it for backward compatibility
+            llm_response = self.llm.invoke(
+                input=prompt,
+                message_history=message_history,
+                system_instruction=self.prompt_template.system_instructions,
+            )
+        else:
+            raise ValueError(f"Type {type(self.llm)} of LLM is not supported.")
         answer = llm_response.content
         result: dict[str, Any] = {"answer": answer}
         if return_context:
@@ -163,18 +194,47 @@ def _build_query(
         query_text: str,
         message_history: Optional[List[LLMMessage]] = None,
     ) -> str:
-        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
+        """Builds the final query text, incorporating message history if provided."""
+        summary_system_message = (
+            "You are a summarization assistant. "
+            "Summarize the given text in no more than 300 words."
+        )
         if message_history:
             summarization_prompt = self._chat_summary_prompt(
                 message_history=message_history
             )
-            summary = self.llm.invoke(
-                input=summarization_prompt,
-                system_instruction=summary_system_message,
-            ).content
+            if self.is_langchain_compatible():
+                messages = legacy_inputs_to_messages(
+                    summarization_prompt,
+                    system_instruction=summary_system_message,
+                )
+                summary = self.llm.invoke(
+                    input=messages,
+                ).content
+            elif isinstance(self.llm, LLMInterface):
+                summary = self.llm.invoke(
+                    input=summarization_prompt,
+                    system_instruction=summary_system_message,
+                ).content
+            else:
+                raise ValueError(f"Type {type(self.llm)} of LLM is not supported.")
+
             return self.conversation_prompt(summary=summary, current_query=query_text)
         return query_text
 
+    def is_langchain_compatible(self) -> bool:
+        """Checks if the LLM is compatible with LangChain."""
+        if isinstance(self.llm, LLMInterfaceV2):
+            return True
+
+        try:
+            # langchain-core is an optional dependency
+            from langchain_core.language_models import BaseChatModel
+
+            return isinstance(self.llm, BaseChatModel)
+        except ImportError:
+            return False
+
     def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
         message_list = [
             f"{message['role']}: {message['content']}" for message in message_history

src/neo4j_graphrag/llm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,7 @@
 from typing import Any
 
 from .anthropic_llm import AnthropicLLM
-from .base import LLMInterface
+from .base import LLMInterface, LLMInterfaceV2
 from .cohere_llm import CohereLLM
 from .mistralai_llm import MistralAILLM
 from .ollama_llm import OllamaLLM
@@ -30,6 +30,7 @@
     "CohereLLM",
     "LLMResponse",
     "LLMInterface",
+    "LLMInterfaceV2",
     "OllamaLLM",
     "OpenAILLM",
     "VertexAILLM",
