Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e9712a9
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 15, 2024
b52c45e
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 16, 2024
84c1780
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 17, 2024
47d4782
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 21, 2024
bc7a2f9
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 22, 2024
a945284
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 22, 2024
4e13c23
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 23, 2024
5367bed
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 24, 2024
21d1223
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 25, 2024
3329cd7
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 25, 2024
d8f6364
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 28, 2024
4cec2f3
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 4, 2024
4445b49
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 5, 2024
939b18c
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 18, 2024
1104519
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 22, 2024
1893b85
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 25, 2024
6e4ebda
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 28, 2024
8db7f01
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Dec 9, 2024
085cf10
Doc + bug fix
stellasia Dec 10, 2024
c23616b
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python…
stellasia Dec 12, 2024
4cb7842
Do not change the behavior, just document they said
stellasia Dec 12, 2024
a615399
Use same order for patched functions and check order of mocked object
stellasia Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,21 @@ RagTemplate

.. autoclass:: neo4j_graphrag.generation.prompts.RagTemplate
:members:
:exclude-members: format

ERExtractionTemplate
--------------------

.. autoclass:: neo4j_graphrag.generation.prompts.ERExtractionTemplate
:members:
:exclude-members: format

Text2CypherTemplate
--------------------

.. autoclass:: neo4j_graphrag.generation.prompts.Text2CypherTemplate
:members:
:exclude-members: format


****
Expand Down
5 changes: 3 additions & 2 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ are listed in [the last section of this file](#customize).

- [Control result format for VectorRetriever](customize/retrievers/result_formatter_vector_retriever.py)
- [Control result format for VectorCypherRetriever](customize/retrievers/result_formatter_vector_cypher_retriever.py)

- [Use pre-filters](customize/retrievers/use_pre_filters.py)
- [Text2Cypher: use a custom prompt](customize/retrievers/text2cypher_custom_prompt.py)

### LLMs

Expand All @@ -74,7 +75,7 @@ are listed in [the last section of this file](#customize).

### Prompts

- [Using a custom prompt](old/graphrag_custom_prompt.py)
- [Using a custom prompt for RAG](customize/answer/custom_prompt.py)


### Embedders
Expand Down
76 changes: 76 additions & 0 deletions examples/customize/retrievers/text2cypher_custom_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""The example shows how to provide a custom prompt to Text2CypherRetriever.

Example using the OpenAILLM, hence the OPENAI_API_KEY needs to be set in the
environment for this example to run.
"""

import neo4j
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import Text2CypherRetriever
from neo4j_graphrag.schema import get_schema

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"

# Create LLM object
llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

# (Optional) Specify your own Neo4j schema
# (also see get_structured_schema and get_schema functions)
neo4j_schema = """
Node properties:
User {name: STRING}
Person {name: STRING, born: INTEGER}
Movie {tagline: STRING, title: STRING, released: INTEGER}
Relationship properties:
ACTED_IN {roles: LIST}
DIRECTED {}
REVIEWED {summary: STRING, rating: INTEGER}
The relationships:
(:Person)-[:ACTED_IN]->(:Movie)
(:Person)-[:DIRECTED]->(:Movie)
(:User)-[:REVIEWED]->(:Movie)
"""

prompt = """Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.

Do not use any properties or relationships not included in the schema.
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.

Always filter movies that have not already been reviewed by the user with name: '{user_name}' using for instance:
(m:Movie)<-[:REVIEWED]-(:User {{name: <the_user_name>}})

Schema:
{schema}

Input:
{query_text}

Cypher query:
"""

with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = Text2CypherRetriever(
driver=driver,
llm=llm,
neo4j_schema=neo4j_schema,
# here we provide a custom prompt
custom_prompt=prompt,
neo4j_database=DATABASE,
)

# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
query_text = "Which movies did Hugo Weaving star in?"
print(
retriever.search(
query_text=query_text,
prompt_params={
# you have to specify all placeholder except the {query_text} one
"schema": get_schema(driver),
"user_name": "the user asking question",
},
)
)
4 changes: 1 addition & 3 deletions examples/retrieve/text2cypher_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
Movie {tagline: STRING, title: STRING, released: INTEGER}
Relationship properties:
ACTED_IN {roles: LIST}
DIRECTED {}
REVIEWED {summary: STRING, rating: INTEGER}
The relationships:
(:Person)-[:ACTED_IN]->(:Movie)
(:Person)-[:DIRECTED]->(:Movie)
(:Person)-[:PRODUCED]->(:Movie)
(:Person)-[:WROTE]->(:Movie)
(:Person)-[:FOLLOWS]->(:Person)
(:Person)-[:REVIEWED]->(:Movie)
"""

Expand Down
38 changes: 18 additions & 20 deletions src/neo4j_graphrag/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Text2CypherRetriever(Retriever):
"""
Allows for the retrieval of records from a Neo4j database using natural language.
Converts a user's natural language query to a Cypher query using an LLM,
then retrieves records from a Neo4j database using the generated Cypher query
then retrieves records from a Neo4j database using the generated Cypher query.

Args:
driver (neo4j.Driver): The Neo4j Python driver.
Expand Down Expand Up @@ -98,23 +98,23 @@ def __init__(
self.examples = validated_data.examples
self.result_formatter = validated_data.result_formatter
self.custom_prompt = validated_data.custom_prompt
try:
if validated_data.custom_prompt:
neo4j_schema = ""
else:
if (
not validated_data.custom_prompt
): # don't need schema for a custom prompt
self.neo4j_schema = (
validated_data.neo4j_schema_model.neo4j_schema
if validated_data.neo4j_schema_model
else get_schema(validated_data.driver_model.driver)
)
validated_data.neo4j_schema_model
and validated_data.neo4j_schema_model.neo4j_schema
):
neo4j_schema = validated_data.neo4j_schema_model.neo4j_schema
else:
self.neo4j_schema = ""

except (Neo4jError, DriverError) as e:
error_message = getattr(e, "message", str(e))
raise SchemaFetchError(
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
) from e
try:
neo4j_schema = get_schema(validated_data.driver_model.driver)
except (Neo4jError, DriverError) as e:
error_message = getattr(e, "message", str(e))
raise SchemaFetchError(
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
) from e
self.neo4j_schema = neo4j_schema

def get_search_results(
self, query_text: str, prompt_params: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -142,12 +142,10 @@ def get_search_results(

if prompt_params is not None:
# parse the schema and examples inputs
examples_to_use = prompt_params.get("examples") or (
examples_to_use = prompt_params.pop("examples", None) or (
"\n".join(self.examples) if self.examples else ""
)
schema_to_use = prompt_params.get("schema") or self.neo4j_schema
prompt_params.pop("examples", None)
prompt_params.pop("schema", None)
schema_to_use = prompt_params.pop("schema", None) or self.neo4j_schema
else:
examples_to_use = "\n".join(self.examples) if self.examples else ""
schema_to_use = self.neo4j_schema
Expand Down
35 changes: 32 additions & 3 deletions tests/unit/retrievers/test_text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from neo4j.exceptions import CypherSyntaxError, Neo4jError
from neo4j_graphrag.exceptions import (
RetrieverInitializationError,
SchemaFetchError,
SearchValidationError,
Text2CypherRetrievalError,
)
Expand All @@ -39,8 +40,8 @@ def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_schema_retrieval(
_verify_version_mock: MagicMock,
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
Expand All @@ -51,13 +52,13 @@ def test_t2c_retriever_schema_retrieval(
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_schema_retrieval_failure(
_verify_version_mock: MagicMock,
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
get_schema_mock.side_effect = Neo4jError
with pytest.raises(Neo4jError):
with pytest.raises(SchemaFetchError):
Text2CypherRetriever(driver, llm)


Expand Down Expand Up @@ -310,3 +311,31 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
llm.invoke.assert_called_once_with(
"""This is a custom prompt. test ['example A', 'example B']"""
)


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_with_custom_prompt_and_schema(
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
prompt = "This is a custom prompt. {query_text} {schema}"
query = "test"

driver.execute_query.return_value = (
[neo4j_record],
None,
None,
)

retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt)
retriever.search(
query_text=query,
prompt_params={},
)

get_schema_mock.assert_not_called()
llm.invoke.assert_called_once_with("""This is a custom prompt. test """)
Loading