Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e9712a9
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 15, 2024
b52c45e
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 16, 2024
84c1780
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 17, 2024
47d4782
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 21, 2024
bc7a2f9
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 22, 2024
a945284
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 22, 2024
4e13c23
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 23, 2024
5367bed
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 24, 2024
21d1223
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 25, 2024
3329cd7
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 25, 2024
d8f6364
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 28, 2024
4cec2f3
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 4, 2024
4445b49
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 5, 2024
939b18c
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 18, 2024
9a755bb
Use self.neo4j_database for all queries in Neo4jWriter
stellasia Nov 22, 2024
53d162c
Make sure all execute_query can be run against a custom database
stellasia Nov 22, 2024
dadcab1
Update CHANGELOG
stellasia Nov 22, 2024
ff2c358
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python…
stellasia Nov 22, 2024
54e6cca
Update docstring + update examples not to use undocumented feature fo…
stellasia Nov 22, 2024
af0eacf
Expose neo4j_database in SimpleKGBuilder
stellasia Nov 22, 2024
9fcd247
Update CHANGELOG
stellasia Nov 22, 2024
80718ee
Simplify changelog
stellasia Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
## Next

### Added
- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
- Ability to provide description and list of properties for entities and relations in the SimpleKGPipeline constructor.
- Introduced optional lexical graph configuration for `SimpleKGPipeline`, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
- Introduced optional `neo4j_database` parameter for `SimpleKGPipeline`, `Neo4jChunkReader` and `Text2CypherRetriever`.
- Ability to provide description and list of properties for entities and relations in the `SimpleKGPipeline` constructor.

### Fixed
- `neo4j_database` parameter is now used for all queries in the `Neo4jWriter`.

### Changed
- Updated all examples to use the `neo4j_database` parameter instead of the undocumented `database` argument of the neo4j driver constructor.

## 1.2.0

Expand Down
3 changes: 2 additions & 1 deletion examples/build_graph/simple_kg_builder_from_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def define_and_run_pipeline(
entities=ENTITIES,
relations=RELATIONS,
potential_schema=POTENTIAL_SCHEMA,
neo4j_database=DATABASE,
)
return await kg_builder.run_async(file_path=str(file_path))

Expand All @@ -62,7 +63,7 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
res = await define_and_run_pipeline(driver, llm)
await llm.async_client.close()
return res
Expand Down
5 changes: 3 additions & 2 deletions examples/build_graph/simple_kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Neo4j db infos
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
DATABASE = "newdb"

# Text to process
TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides,
Expand Down Expand Up @@ -67,6 +67,7 @@ async def define_and_run_pipeline(
relations=RELATIONS,
potential_schema=POTENTIAL_SCHEMA,
from_pdf=False,
neo4j_database=DATABASE,
)
return await kg_builder.run_async(text=TEXT)

Expand All @@ -79,7 +80,7 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
res = await define_and_run_pipeline(driver, llm)
await llm.async_client.close()
return res
Expand Down
2 changes: 1 addition & 1 deletion examples/customize/answer/custom_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

embedder = OpenAIEmbeddings()
Expand All @@ -33,6 +32,7 @@
index_name=INDEX,
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
Expand Down
2 changes: 1 addition & 1 deletion examples/customize/answer/langchain_compatiblity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

embedder = OpenAIEmbeddings(model="text-embedding-ada-002")
Expand All @@ -31,6 +30,7 @@
index_name=INDEX,
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
embedder=embedder, # type: ignore[arg-type, unused-ignore]
neo4j_database=DATABASE,
)

llm = ChatOpenAI(model="gpt-4o", temperature=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem:
)


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorCypherRetriever(
driver=driver,
Expand All @@ -48,7 +48,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem:
retrieval_query=RETRIEVAL_QUERY,
result_formatter=my_result_formatter,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


# Connect to Neo4j database
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)


query_text = "Find a movie about astronauts"
Expand All @@ -52,6 +52,7 @@
index_name=INDEX_NAME,
embedder=OpenAIEmbeddings(),
return_properties=["title", "plot"],
neo4j_database=DATABASE,
)
print(retriever.search(query_text=query_text, top_k=top_k_results))
print()
Expand Down
2 changes: 1 addition & 1 deletion examples/question_answering/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

embedder = OpenAIEmbeddings()
Expand All @@ -46,6 +45,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
retrieval_query="with node, score return node.title as title, node.plot as plot",
result_formatter=formatter,
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/hybrid_cypher_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# the name of all actors starring in that movie
RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore"

with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = HybridCypherRetriever(
driver=driver,
Expand All @@ -37,7 +37,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/hybrid_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
FULLTEXT_INDEX_NAME = "movieFulltext"


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = HybridRetriever(
driver=driver,
Expand All @@ -31,7 +31,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
INDEX_NAME = "moviePlotsEmbedding"


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorRetriever(
driver=driver,
Expand All @@ -29,7 +29,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
3 changes: 2 additions & 1 deletion examples/retrieve/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
INDEX_NAME = "moviePlotsEmbedding"


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorRetriever(
driver=driver,
index_name=INDEX_NAME,
neo4j_database=DATABASE,
)

# Perform the similarity search for a vector query
Expand Down
1 change: 1 addition & 0 deletions examples/retrieve/text2cypher_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
# optionally, you can also provide your own prompt
# for the text2Cypher generation step
# custom_prompt="",
neo4j_database=DATABASE,
)

# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/vector_cypher_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# the name of all actors starring in that movie
RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore"

with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorCypherRetriever(
driver=driver,
Expand All @@ -34,7 +34,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
31 changes: 22 additions & 9 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Neo4jWriter(KGWriter):

Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.

Example:
Expand All @@ -99,7 +99,7 @@ class Neo4jWriter(KGWriter):
AUTH = ("neo4j", "password")
DATABASE = "neo4j"

driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = GraphDatabase.driver(URI, auth=AUTH)
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)

pipeline = Pipeline()
Expand All @@ -119,10 +119,11 @@ def __init__(
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()

def _db_setup(self) -> None:
# create index on __Entity__.id
# create index on __KGBuilder__.id
# used when creating the relationships
self.driver.execute_query(
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)",
database_=self.neo4j_database,
)

@staticmethod
Expand Down Expand Up @@ -150,10 +151,16 @@ def _upsert_nodes(
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
if self.is_version_5_23_or_above:
self.driver.execute_query(
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
database_=self.neo4j_database,
)
else:
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
self.driver.execute_query(
UPSERT_NODE_QUERY,
parameters_=parameters,
database_=self.neo4j_database,
)

def _get_version(self) -> tuple[int, ...]:
records, _, _ = self.driver.execute_query(
Expand Down Expand Up @@ -187,10 +194,16 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
parameters = {"rows": [rel.model_dump() for rel in rels]}
if self.is_version_5_23_or_above:
self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
database_=self.neo4j_database,
)
else:
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY,
parameters_=parameters,
database_=self.neo4j_database,
)

@validate_call
async def run(
Expand All @@ -202,7 +215,7 @@ async def run(

Args:
graph (Neo4jGraph): The knowledge graph to upsert into the database.
lexical_graph_config (LexicalGraphConfig):
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
"""
try:
self._db_setup()
Expand Down
38 changes: 37 additions & 1 deletion src/neo4j_graphrag/experimental/components/neo4j_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations

from typing import Optional

import neo4j
from pydantic import validate_call

Expand All @@ -26,13 +28,39 @@


class Neo4jChunkReader(Component):
"""Reads text chunks from a Neo4j database.

Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
fetch_embeddings (bool): If True, the embedding property is also returned. Defaults to False.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).

Example:

.. code-block:: python

from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"

driver = GraphDatabase.driver(URI, auth=AUTH)
reader = Neo4jChunkReader(driver=driver, neo4j_database=DATABASE)
await reader.run()

"""

def __init__(
self,
driver: neo4j.Driver,
fetch_embeddings: bool = False,
neo4j_database: Optional[str] = None,
):
self.driver = driver
self.fetch_embeddings = fetch_embeddings
self.neo4j_database = neo4j_database

def _get_query(
self,
Expand All @@ -56,12 +84,20 @@ async def run(
self,
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
) -> TextChunks:
"""Reads text chunks from a Neo4j database.

Args:
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
"""
query = self._get_query(
lexical_graph_config.chunk_node_label,
lexical_graph_config.chunk_index_property,
lexical_graph_config.chunk_embedding_property,
)
result, _, _ = self.driver.execute_query(query)
result, _, _ = self.driver.execute_query(
query,
database_=self.neo4j_database,
)
chunks = []
for record in result:
chunk = record.get("chunk")
Expand Down
Loading