Support for conversations with message history #234
Merged
Changes from all 36 commits, authored by leila-messallem:
- ec65a48 Add system_instruction parameter
- 0537565 Add chat_history parameter
- 2e7e8bf Add missing doc strings
- 3c19041 Open AI
- 226c08b Add a summary of the chat history to the query embedding
- 6d101dd Anthropic
- d8f3948 Change return type of Anthropic get_messages()
- b8910df Cohere
- 72f4de5 Mistral
- 5720a4b VertexAI
- 615cea6 Merge branch 'main' into chat-history
- 597eff1 Formatting
- f2792ff Merge branch 'chat-history' of github.com:leila-messallem/neo4j-graph…
- 6288907 Fix mypy errors
- a362fd3 Ollama
- 5bb56f6 Override of the system message
- 6aea7fa Use TYPE_CHECKING for dev dependencies
- 07038dd Formatting
- 37225fd Rename `chat_history` to `message_history`
- abef33c Use BaseMessage class type
- d7df9e8 System instruction override
- a749a9e Merge branch 'main' into chat-history
- 819179e Revert BaseMessage class type
- 2143973 Fix mypy errors
- 775447f Update tests
- 17db6b1 Fix ollama NameError
- 3c55d3f Fix NameError in unit tests
- d5a287b Add TypeDict `LLMMessage`
- bd34e1a Simplify the retriever prompt
- 23a8001 Fix E2E tests
- fa12a9f Unit tests for the system instruction override
- f5a9833 Move and rename the prompts
- 81f7ff4 Update changelog
- a15a514 Add missing parameter in example
- 7557b07 Add LLMMessage to the docs
- 717be1c Update docs README
```diff
@@ -27,6 +27,7 @@
 from neo4j_graphrag.generation.prompts import RagTemplate
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm.types import LLMMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult
```
```diff
@@ -83,6 +84,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
+        message_history: Optional[list[LLMMessage]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
```
```diff
@@ -99,14 +101,15 @@ def search(
         Args:
-            query_text (str): The user question
+            query_text (str): The user question.
+            message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
-            retriever_config (Optional[dict]): Parameters passed to the retriever
+            retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
-            return_context (bool): Whether to append the retriever result to the final result (default: False)
+            return_context (bool): Whether to append the retriever result to the final result (default: False).

         Returns:
-            RagResultModel: The LLM-generated answer
+            RagResultModel: The LLM-generated answer.

         """
         if return_context is None:
```
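For reference, judging from the `Add TypeDict LLMMessage` commit and the `from neo4j_graphrag.llm.types import LLMMessage` import above, a `message_history` entry is a plain dict with `role` and `content` keys. A minimal sketch of the assumed shape:

```python
from neo4j_graphrag.llm.types import LLMMessage

# Assumed two-key shape; the exact fields of LLMMessage are inferred
# from how the diff below iterates over message.items().
message_history: list[LLMMessage] = [
    {"role": "user", "content": "When was Neo4j founded?"},
    {"role": "assistant", "content": "Neo4j was founded in 2007."},
]
```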
```diff
@@ -124,18 +127,54 @@ def search(
             )
         except ValidationError as e:
             raise SearchValidationError(e.errors())
-        query_text = validated_data.query_text
+        query = self.build_query(validated_data.query_text, message_history)
         retriever_result: RetrieverResult = self.retriever.search(
-            query_text=query_text, **validated_data.retriever_config
+            query_text=query, **validated_data.retriever_config
         )
         context = "\n".join(item.content for item in retriever_result.items)
         prompt = self.prompt_template.format(
             query_text=query_text, context=context, examples=validated_data.examples
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt)
+        answer = self.llm.invoke(prompt, message_history)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)

+    def build_query(
+        self, query_text: str, message_history: Optional[list[LLMMessage]] = None
+    ) -> str:
+        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
+        if message_history:
+            summarization_prompt = self.chat_summary_prompt(
+                message_history=message_history
+            )
+            summary = self.llm.invoke(
```
> **Contributor** (review comment on the `self.llm.invoke(` line above): I'm wondering if we should allow the user to use a different LLM for summarization. I'm thinking users might want to use a "small" LLM for this simple task, and use a "better" one for the Q&A part. But we can leave it for a later improvement.
```diff
+                input=summarization_prompt,
+                system_instruction=summary_system_message,
+            ).content
+            return self.conversation_prompt(summary=summary, current_query=query_text)
+        return query_text
```
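Following up on the review comment above: one possible shape for that later improvement, hypothetical and not part of this PR, would be an optional dedicated summarizer LLM that defaults to the main Q&A LLM. A sketch under that assumption:

```python
from typing import Optional

from neo4j_graphrag.llm import LLMInterface


# Illustrative only: the real class in this PR is GraphRAG, which always
# reuses self.llm for summarization.
class GraphRAGWithSeparateSummarizer:
    def __init__(
        self,
        llm: LLMInterface,
        summarizer_llm: Optional[LLMInterface] = None,  # hypothetical parameter
    ) -> None:
        self.llm = llm
        # Fall back to the main Q&A LLM when no dedicated summarizer is given,
        # so callers can pass a smaller, cheaper model just for this task.
        self.summarizer_llm = summarizer_llm or llm
```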
```diff
+    def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
+        message_list = [
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in message_history
+        ]
+        history = "\n".join(message_list)
+        return f"""
+Summarize the message history:
+
+{history}
+"""
```
```diff
+    def conversation_prompt(self, summary: str, current_query: str) -> str:
+        return f"""
+Message Summary:
+{summary}
+
+Current Query:
+{current_query}
+"""
```
There was an error while loading. Please reload this page.