diff --git a/CHANGELOG.md b/CHANGELOG.md index f82e76854..c99d49b84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,8 +3,15 @@ ## Next ### Added -- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph. -- Ability to provide description and list of properties for entities and relations in the SimpleKGPipeline constructor. +- Introduced optional lexical graph configuration for `SimpleKGPipeline`, enhancing flexibility in customizing node labels and relationship types in the lexical graph. +- Introduced optional `neo4j_database` parameter for `SimpleKGPipeline`, `Neo4jChunkReader`and `Text2CypherRetriever`. +- Ability to provide description and list of properties for entities and relations in the `SimpleKGPipeline` constructor. + +### Fixed +- `neo4j_database` parameter is now used for all queries in the `Neo4jWriter`. + +### Changed +- Updated all examples to use `neo4j_database` parameter instead of an undocumented neo4j driver constructor. ## 1.2.0 diff --git a/examples/build_graph/simple_kg_builder_from_pdf.py b/examples/build_graph/simple_kg_builder_from_pdf.py index 7b33b256e..fd79a1d63 100644 --- a/examples/build_graph/simple_kg_builder_from_pdf.py +++ b/examples/build_graph/simple_kg_builder_from_pdf.py @@ -50,6 +50,7 @@ async def define_and_run_pipeline( entities=ENTITIES, relations=RELATIONS, potential_schema=POTENTIAL_SCHEMA, + neo4j_database=DATABASE, ) return await kg_builder.run_async(file_path=str(file_path)) @@ -62,7 +63,7 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: + with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: res = await define_and_run_pipeline(driver, llm) await llm.async_client.close() return res diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index d8d83ed72..2dec770e5 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -21,7 +21,7 @@ # Neo4j db infos URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") -DATABASE = "neo4j" +DATABASE = "newdb" # Text to process TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides, @@ -67,6 +67,7 @@ async def define_and_run_pipeline( relations=RELATIONS, potential_schema=POTENTIAL_SCHEMA, from_pdf=False, + neo4j_database=DATABASE, ) return await kg_builder.run_async(text=TEXT) @@ -79,7 +80,7 @@ async def main() -> PipelineResult: "response_format": {"type": "json_object"}, }, ) - with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: + with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: res = await define_and_run_pipeline(driver, llm) await llm.async_client.close() return res diff --git a/examples/customize/answer/custom_prompt.py b/examples/customize/answer/custom_prompt.py index 4089c7031..a31c1a941 100644 --- a/examples/customize/answer/custom_prompt.py +++ b/examples/customize/answer/custom_prompt.py @@ -23,7 +23,6 @@ driver = neo4j.GraphDatabase.driver( URI, auth=AUTH, - database=DATABASE, ) embedder = OpenAIEmbeddings() @@ -33,6 +32,7 @@ index_name=INDEX, retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot", embedder=embedder, + neo4j_database=DATABASE, ) llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0}) diff --git a/examples/customize/answer/langchain_compatiblity.py b/examples/customize/answer/langchain_compatiblity.py index 858dd12e7..f021b8cd3 100644 --- a/examples/customize/answer/langchain_compatiblity.py +++ b/examples/customize/answer/langchain_compatiblity.py @@ -21,7 +21,6 @@ driver = neo4j.GraphDatabase.driver( URI, auth=AUTH, - database=DATABASE, ) embedder = OpenAIEmbeddings(model="text-embedding-ada-002") @@ -31,6 +30,7 @@ index_name=INDEX, retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot", embedder=embedder, # type: ignore[arg-type, unused-ignore] + neo4j_database=DATABASE, ) llm = ChatOpenAI(model="gpt-4o", temperature=0) diff --git a/examples/customize/retrievers/result_formatter_vector_cypher_retriever.py b/examples/customize/retrievers/result_formatter_vector_cypher_retriever.py index 2a8f520db..42dda21f6 100644 --- a/examples/customize/retrievers/result_formatter_vector_cypher_retriever.py +++ b/examples/customize/retrievers/result_formatter_vector_cypher_retriever.py @@ -38,7 +38,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem: ) -with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: # Initialize the retriever retriever = VectorCypherRetriever( driver=driver, @@ -48,7 +48,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem: retrieval_query=RETRIEVAL_QUERY, result_formatter=my_result_formatter, # optionally, set neo4j database - # neo4j_database="neo4j", + neo4j_database=DATABASE, ) # Perform the similarity search for a text query diff --git a/examples/customize/retrievers/result_formatter_vector_retriever.py b/examples/customize/retrievers/result_formatter_vector_retriever.py index e4c7448f3..77734ede7 100644 --- a/examples/customize/retrievers/result_formatter_vector_retriever.py +++ b/examples/customize/retrievers/result_formatter_vector_retriever.py @@ -35,7 +35,7 @@ # Connect to Neo4j database -driver = neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) +driver = neo4j.GraphDatabase.driver(URI, auth=AUTH) query_text = "Find a movie about astronauts" @@ -52,6 +52,7 @@ index_name=INDEX_NAME, embedder=OpenAIEmbeddings(), return_properties=["title", "plot"], + neo4j_database=DATABASE, ) print(retriever.search(query_text=query_text, top_k=top_k_results)) print() diff --git a/examples/question_answering/graphrag.py b/examples/question_answering/graphrag.py index c20f622df..526f83c78 100644 --- a/examples/question_answering/graphrag.py +++ b/examples/question_answering/graphrag.py @@ -35,7 +35,6 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem: driver = neo4j.GraphDatabase.driver( URI, auth=AUTH, - database=DATABASE, ) embedder = OpenAIEmbeddings() @@ -46,6 +45,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem: retrieval_query="with node, score return node.title as title, node.plot as plot", result_formatter=formatter, embedder=embedder, + neo4j_database=DATABASE, ) llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0}) diff --git a/examples/retrieve/hybrid_cypher_retriever.py b/examples/retrieve/hybrid_cypher_retriever.py index 232e0c439..f29a214e5 100644 --- a/examples/retrieve/hybrid_cypher_retriever.py +++ b/examples/retrieve/hybrid_cypher_retriever.py @@ -24,7 +24,7 @@ # the name of all actors starring in that movie RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore" -with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: # Initialize the retriever retriever = HybridCypherRetriever( driver=driver, @@ -37,7 +37,7 @@ # (see corresponding example in 'customize' directory) # result_formatter=None, # optionally, set neo4j database - # neo4j_database="neo4j", + neo4j_database=DATABASE, ) # Perform the similarity search for a text query diff --git a/examples/retrieve/hybrid_retriever.py b/examples/retrieve/hybrid_retriever.py index df9e1d295..51e5e6686 100644 --- a/examples/retrieve/hybrid_retriever.py +++ b/examples/retrieve/hybrid_retriever.py @@ -18,7 +18,7 @@ FULLTEXT_INDEX_NAME = "movieFulltext" -with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: # Initialize the retriever retriever = HybridRetriever( driver=driver, @@ -31,7 +31,7 @@ # (see corresponding example in 'customize' directory) # result_formatter=None, # optionally, set neo4j database - # neo4j_database="neo4j", + neo4j_database=DATABASE, ) # Perform the similarity search for a text query diff --git a/examples/retrieve/similarity_search_for_text.py b/examples/retrieve/similarity_search_for_text.py index 32ddda5de..bd4aaf960 100644 --- a/examples/retrieve/similarity_search_for_text.py +++ b/examples/retrieve/similarity_search_for_text.py @@ -17,7 +17,7 @@ INDEX_NAME = "moviePlotsEmbedding" -with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: # Initialize the retriever retriever = VectorRetriever( driver=driver, @@ -29,7 +29,7 @@ # (see corresponding example in 'customize' directory) # result_formatter=None, # optionally, set neo4j database - # neo4j_database="neo4j", + neo4j_database=DATABASE, ) # Perform the similarity search for a text query diff --git a/examples/retrieve/similarity_search_for_vector.py b/examples/retrieve/similarity_search_for_vector.py index 43b38a0e7..9612f6efb 100644 --- a/examples/retrieve/similarity_search_for_vector.py +++ b/examples/retrieve/similarity_search_for_vector.py @@ -17,11 +17,12 @@ INDEX_NAME = "moviePlotsEmbedding" -with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: # Initialize the retriever retriever = VectorRetriever( driver=driver, index_name=INDEX_NAME, + neo4j_database=DATABASE, ) # Perform the similarity search for a vector query diff --git a/examples/retrieve/text2cypher_search.py b/examples/retrieve/text2cypher_search.py index 16326ba17..e17cb65d1 100644 --- a/examples/retrieve/text2cypher_search.py +++ b/examples/retrieve/text2cypher_search.py @@ -47,6 +47,7 @@ # optionally, you can also provide your own prompt # for the text2Cypher generation step # custom_prompt="", + neo4j_database=DATABASE, ) # Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results diff --git a/examples/retrieve/vector_cypher_retriever.py b/examples/retrieve/vector_cypher_retriever.py index d4b98334f..e8dab571c 100644 --- a/examples/retrieve/vector_cypher_retriever.py +++ b/examples/retrieve/vector_cypher_retriever.py @@ -22,7 +22,7 @@ # the name of all actors starring in that movie RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore" -with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver: +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: # Initialize the retriever retriever = VectorCypherRetriever( driver=driver, @@ -34,7 +34,7 @@ # (see corresponding example in 'customize' directory) # result_formatter=None, # optionally, set neo4j database - # neo4j_database="neo4j", + neo4j_database=DATABASE, ) # Perform the similarity search for a text query diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index e1316f111..fad24b7a9 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -84,7 +84,7 @@ class Neo4jWriter(KGWriter): Args: driver (neo4j.driver): The Neo4j driver to connect to the database. - neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided. + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000. Example: @@ -99,7 +99,7 @@ class Neo4jWriter(KGWriter): AUTH = ("neo4j", "password") DATABASE = "neo4j" - driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) + driver = GraphDatabase.driver(URI, auth=AUTH) writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE) pipeline = Pipeline() @@ -119,10 +119,11 @@ def __init__( self.is_version_5_23_or_above = self._check_if_version_5_23_or_above() def _db_setup(self) -> None: - # create index on __Entity__.id + # create index on __KGBuilder__.id # used when creating the relationships self.driver.execute_query( - "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)" + "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)", + database_=self.neo4j_database, ) @staticmethod @@ -150,10 +151,16 @@ def _upsert_nodes( parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)} if self.is_version_5_23_or_above: self.driver.execute_query( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_=parameters, + database_=self.neo4j_database, ) else: - self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters) + self.driver.execute_query( + UPSERT_NODE_QUERY, + parameters_=parameters, + database_=self.neo4j_database, + ) def _get_version(self) -> tuple[int, ...]: records, _, _ = self.driver.execute_query( @@ -187,10 +194,16 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: parameters = {"rows": [rel.model_dump() for rel in rels]} if self.is_version_5_23_or_above: self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_=parameters, + database_=self.neo4j_database, ) else: - self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters) + self.driver.execute_query( + UPSERT_RELATIONSHIP_QUERY, + parameters_=parameters, + database_=self.neo4j_database, + ) @validate_call async def run( @@ -202,7 +215,7 @@ async def run( Args: graph (Neo4jGraph): The knowledge graph to upsert into the database. - lexical_graph_config (LexicalGraphConfig): + lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph. """ try: self._db_setup() diff --git a/src/neo4j_graphrag/experimental/components/neo4j_reader.py b/src/neo4j_graphrag/experimental/components/neo4j_reader.py index 6d384cf8e..8aee5d1cd 100644 --- a/src/neo4j_graphrag/experimental/components/neo4j_reader.py +++ b/src/neo4j_graphrag/experimental/components/neo4j_reader.py @@ -14,6 +14,8 @@ # limitations under the License. from __future__ import annotations +from typing import Optional + import neo4j from pydantic import validate_call @@ -26,13 +28,39 @@ class Neo4jChunkReader(Component): + """Reads text chunks from a Neo4j database. + + Args: + driver (neo4j.driver): The Neo4j driver to connect to the database. + fetch_embeddings (bool): If True, the embedding property is also returned. Default to False. + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). + + Example: + + .. code-block:: python + + from neo4j import GraphDatabase + from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader + + URI = "neo4j://localhost:7687" + AUTH = ("neo4j", "password") + DATABASE = "neo4j" + + driver = GraphDatabase.driver(URI, auth=AUTH) + reader = Neo4jChunkReader(driver=driver, neo4j_database=DATABASE) + await reader.run() + + """ + def __init__( self, driver: neo4j.Driver, fetch_embeddings: bool = False, + neo4j_database: Optional[str] = None, ): self.driver = driver self.fetch_embeddings = fetch_embeddings + self.neo4j_database = neo4j_database def _get_query( self, @@ -56,12 +84,20 @@ async def run( self, lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(), ) -> TextChunks: + """Reads text chunks from a Neo4j database. + + Args: + lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph. + """ query = self._get_query( lexical_graph_config.chunk_node_label, lexical_graph_config.chunk_index_property, lexical_graph_config.chunk_embedding_property, ) - result, _, _ = self.driver.execute_query(query) + result, _, _ = self.driver.execute_query( + query, + database_=self.neo4j_database, + ) chunks = [] for record in result: chunk = record.get("chunk") diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index d9fa8f994..142166407 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -49,7 +49,7 @@ class SinglePropertyExactMatchResolver(EntityResolver): driver (neo4j.Driver): The Neo4j driver to connect to the database. filter_query (Optional[str]): To reduce the resolution scope, add a Cypher WHERE clause. resolve_property (str): The property that will be compared (default: "name"). If values match exactly, entities are merged. - neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided. + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Example: @@ -62,7 +62,7 @@ class SinglePropertyExactMatchResolver(EntityResolver): AUTH = ("neo4j", "password") DATABASE = "neo4j" - driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) + driver = GraphDatabase.driver(URI, auth=AUTH) resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE) await resolver.run() # no expected parameters @@ -77,7 +77,7 @@ def __init__( ) -> None: super().__init__(driver, filter_query) self.resolve_property = resolve_property - self.database = neo4j_database + self.neo4j_database = neo4j_database async def run(self) -> ResolutionStats: """Resolve entities based on the following rule: @@ -93,7 +93,9 @@ async def run(self) -> ResolutionStats: if self.filter_query: match_query += self.filter_query stat_query = f"{match_query} RETURN count(entity) as c" - records, _, _ = self.driver.execute_query(stat_query, database_=self.database) + records, _, _ = self.driver.execute_query( + stat_query, database_=self.neo4j_database + ) number_of_nodes_to_resolve = records[0].get("c") if number_of_nodes_to_resolve == 0: return ResolutionStats( @@ -126,7 +128,7 @@ async def run(self) -> ResolutionStats: "RETURN count(node) as c " ) records, _, _ = self.driver.execute_query( - merge_nodes_query, database_=self.database + merge_nodes_query, database_=self.neo4j_database ) number_of_created_nodes = records[0].get("c") return ResolutionStats( diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 33bf6849f..58868cb36 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -65,6 +65,7 @@ class SimpleKGPipelineConfig(BaseModel): prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() perform_entity_resolution: bool = True lexical_graph_config: Optional[LexicalGraphConfig] = None + neo4j_database: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -117,6 +118,7 @@ def __init__( prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(), perform_entity_resolution: bool = True, lexical_graph_config: Optional[LexicalGraphConfig] = None, + neo4j_database: Optional[str] = None, ): self.potential_schema = potential_schema or [] self.entities = [self.to_schema_entity(e) for e in entities or []] @@ -144,6 +146,7 @@ def __init__( embedder=embedder, perform_entity_resolution=perform_entity_resolution, lexical_graph_config=lexical_graph_config, + neo4j_database=neo4j_database, ) self.from_pdf = config.from_pdf @@ -154,11 +157,14 @@ def __init__( self.on_error = config.on_error self.pdf_loader = config.pdf_loader if pdf_loader is not None else PdfLoader() self.kg_writer = ( - config.kg_writer if kg_writer is not None else Neo4jWriter(driver) + config.kg_writer + if kg_writer is not None + else Neo4jWriter(driver, neo4j_database=config.neo4j_database) ) self.prompt_template = config.prompt_template self.perform_entity_resolution = config.perform_entity_resolution self.lexical_graph_config = config.lexical_graph_config + self.neo4j_database = config.neo4j_database self.pipeline = self._build_pipeline() @@ -233,7 +239,10 @@ def _build_pipeline(self) -> Pipeline: if self.perform_entity_resolution: pipe.add_component( - SinglePropertyExactMatchResolver(self.driver), "resolver" + SinglePropertyExactMatchResolver( + self.driver, neo4j_database=self.neo4j_database + ), + "resolver", ) pipe.connect("writer", "resolver", {}) diff --git a/src/neo4j_graphrag/indexes.py b/src/neo4j_graphrag/indexes.py index cb8101c6f..65d7eae10 100644 --- a/src/neo4j_graphrag/indexes.py +++ b/src/neo4j_graphrag/indexes.py @@ -85,7 +85,7 @@ def create_vector_index( similarity_fn (str): case-insensitive values for the vector similarity function: ``euclidean`` or ``cosine``. fail_if_exists (bool): If True raise an error if the index already exists. Defaults to False. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: ValueError: If validation of the input arguments fail. @@ -167,7 +167,7 @@ def create_fulltext_index( label (str): The node label to be indexed. node_properties (list[str]): The node properties to create the fulltext index on. fail_if_exists (bool): If True raise an error if the index already exists. Defaults to False. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: ValueError: If validation of the input arguments fail. @@ -229,7 +229,7 @@ def drop_index_if_exists( Args: driver (neo4j.Driver): Neo4j Python driver instance. name (str): The name of the index to delete. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: neo4j.exceptions.ClientError: If dropping of index fails. @@ -281,7 +281,7 @@ def upsert_vector( node_id (int): The element id of the node. embedding_property (str): The name of the property to store the vector in. vector (list[float]): The vector to store. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: Neo4jInsertionError: If upserting of the vector fails. @@ -337,7 +337,7 @@ def upsert_vector_on_relationship( rel_id (int): The element id of the relationship. embedding_property (str): The name of the property to store the vector in. vector (list[float]): The vector to store. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: Neo4jInsertionError: If upserting of the vector fails. @@ -394,7 +394,7 @@ async def async_upsert_vector( node_id (int): The element id of the node. embedding_property (str): The name of the property to store the vector in. vector (list[float]): The vector to store. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: Neo4jInsertionError: If upserting of the vector fails. @@ -451,7 +451,7 @@ async def async_upsert_vector_on_relationship( rel_id (int): The element id of the relationship. embedding_property (str): The name of the property to store the vector in. vector (list[float]): The vector to store. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: Neo4jInsertionError: If upserting of the vector fails. diff --git a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py index f1b357f5c..588791129 100644 --- a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py +++ b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py @@ -82,7 +82,7 @@ class PineconeNeo4jRetriever(ExternalRetriever): return_properties (Optional[list[str]]): List of node properties to return. retrieval_query (str): Cypher query that gets appended. result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: RetrieverInitializationError: If validation of the input arguments fail. diff --git a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py index b76986238..a38b322fd 100644 --- a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py +++ b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py @@ -76,7 +76,7 @@ class QdrantNeo4jRetriever(ExternalRetriever): embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: RetrieverInitializationError: If validation of the input arguments fail. diff --git a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py index 4a777a333..a6f28b7e8 100644 --- a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py @@ -80,7 +80,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever): embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: RetrieverInitializationError: If validation of the input arguments fail. diff --git a/src/neo4j_graphrag/retrievers/hybrid.py b/src/neo4j_graphrag/retrievers/hybrid.py index ac6006908..c1b97442b 100644 --- a/src/neo4j_graphrag/retrievers/hybrid.py +++ b/src/neo4j_graphrag/retrievers/hybrid.py @@ -70,8 +70,8 @@ class HybridRetriever(Retriever): fulltext_index_name (str): Fulltext index name. embedder (Optional[Embedder]): Embedder object to embed query text. return_properties (Optional[list[str]]): List of node properties to return. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Two variables are provided in the neo4j.Record: @@ -241,7 +241,7 @@ class HybridCypherRetriever(Retriever): retrieval_query (str): Cypher query that gets appended. embedder (Optional[Embedder]): Embedder object to embed query text. result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: RetrieverInitializationError: If validation of the input arguments fail. diff --git a/src/neo4j_graphrag/retrievers/text2cypher.py b/src/neo4j_graphrag/retrievers/text2cypher.py index 2bb2aa699..8297f1238 100644 --- a/src/neo4j_graphrag/retrievers/text2cypher.py +++ b/src/neo4j_graphrag/retrievers/text2cypher.py @@ -71,6 +71,7 @@ def __init__( Callable[[neo4j.Record], RetrieverResultItem] ] = None, custom_prompt: Optional[str] = None, + neo4j_database: Optional[str] = None, ) -> None: try: driver_model = Neo4jDriverModel(driver=driver) @@ -85,11 +86,14 @@ def __init__( examples=examples, result_formatter=result_formatter, custom_prompt=custom_prompt, + neo4j_database=neo4j_database, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e - super().__init__(validated_data.driver_model.driver) + super().__init__( + validated_data.driver_model.driver, validated_data.neo4j_database + ) self.llm = validated_data.llm_model.llm self.examples = validated_data.examples self.result_formatter = validated_data.result_formatter @@ -162,7 +166,9 @@ def get_search_results( llm_result = self.llm.invoke(prompt) t2c_query = llm_result.content logger.debug("Text2CypherRetriever Cypher query: %s", t2c_query) - records, _, _ = self.driver.execute_query(query_=t2c_query) + records, _, _ = self.driver.execute_query( + query_=t2c_query, database_=self.neo4j_database + ) except CypherSyntaxError as e: raise Text2CypherRetrievalError( f"Failed to get search result: {e.message}" diff --git a/src/neo4j_graphrag/retrievers/vector.py b/src/neo4j_graphrag/retrievers/vector.py index ed3e24d7d..cfec83652 100644 --- a/src/neo4j_graphrag/retrievers/vector.py +++ b/src/neo4j_graphrag/retrievers/vector.py @@ -78,7 +78,7 @@ class VectorRetriever(Retriever): - node: Represents the node retrieved from the vector index search. - score: Denotes the similarity score. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Raises: RetrieverInitializationError: If validation of the input arguments fail. @@ -243,7 +243,7 @@ class VectorCypherRetriever(Retriever): retrieval_query (str): Cypher query that gets appended. embedder (Optional[Embedder]): Embedder object to embed query text. result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem. - neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation `_). + neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). Read more in the :ref:`User Guide `. """ diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index d2aa10a33..5a45141dd 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -241,3 +241,4 @@ class Text2CypherRetrieverModel(BaseModel): examples: Optional[list[str]] = None result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None custom_prompt: Optional[str] = None + neo4j_database: Optional[str] = None diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index 0e8fb0aa9..859863bd5 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -72,6 +72,7 @@ def test_upsert_nodes(_: Mock, driver: MagicMock) -> None: } ] }, + database_=None, ) @@ -109,6 +110,7 @@ def test_upsert_nodes_with_embedding( } ] }, + database_=None, ) @@ -143,6 +145,7 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: driver.execute_query.assert_called_once_with( UPSERT_RELATIONSHIP_QUERY, parameters_=parameters, + database_=None, ) @@ -179,6 +182,7 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: driver.execute_query.assert_any_call( UPSERT_RELATIONSHIP_QUERY, parameters_=parameters, + database_=None, ) @@ -210,6 +214,7 @@ async def test_run(_: Mock, driver: MagicMock) -> None: } ] }, + database_=None, ) parameters_ = { "rows": [ @@ -225,6 +230,7 @@ async def test_run(_: Mock, driver: MagicMock) -> None: driver.execute_query.assert_any_call( UPSERT_RELATIONSHIP_QUERY, parameters_=parameters_, + database_=None, ) @@ -257,6 +263,7 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: } ] }, + database_=None, ) parameters_ = { "rows": [ @@ -272,6 +279,7 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: driver.execute_query.assert_any_call( UPSERT_RELATIONSHIP_QUERY, parameters_=parameters_, + database_=None, ) @@ -305,6 +313,7 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: } ] }, + database_=None, ) parameters_ = { "rows": [ @@ -320,4 +329,5 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: driver.execute_query.assert_any_call( UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters_, + database_=None, ) diff --git a/tests/unit/experimental/components/test_neo4j_reader.py b/tests/unit/experimental/components/test_neo4j_reader.py index 39707e135..fec8addef 100644 --- a/tests/unit/experimental/components/test_neo4j_reader.py +++ b/tests/unit/experimental/components/test_neo4j_reader.py @@ -29,11 +29,12 @@ async def test_neo4j_chunk_reader(driver: Mock) -> None: None, None, ) - chunk_reader = Neo4jChunkReader(driver) + chunk_reader = Neo4jChunkReader(driver, neo4j_database="mydb") res = await chunk_reader.run() driver.execute_query.assert_called_once_with( - "MATCH (c:`Chunk`) RETURN c { .*, embedding: null } as chunk ORDER BY c.index" + "MATCH (c:`Chunk`) RETURN c { .*, embedding: null } as chunk ORDER BY c.index", + database_="mydb", ) assert isinstance(res, TextChunks) @@ -72,7 +73,8 @@ async def test_neo4j_chunk_reader_custom_lg_config(driver: Mock) -> None: ) driver.execute_query.assert_called_once_with( - "MATCH (c:`Page`) RETURN c { .*, embedding: null } as chunk ORDER BY c.k" + "MATCH (c:`Page`) RETURN c { .*, embedding: null } as chunk ORDER BY c.k", + database_=None, ) assert isinstance(res, TextChunks) @@ -106,7 +108,8 @@ async def test_neo4j_chunk_reader_fetch_embedding(driver: Mock) -> None: res = await chunk_reader.run() driver.execute_query.assert_called_once_with( - "MATCH (c:`Chunk`) RETURN c { .* } as chunk ORDER BY c.index" + "MATCH (c:`Chunk`) RETURN c { .* } as chunk ORDER BY c.index", + database_=None, ) assert isinstance(res, TextChunks) diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index dddfca622..a23f12f4f 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -116,8 +116,13 @@ def test_t2c_retriever_happy_path( query_text = "may thy knife chip and shatter" neo4j_schema = "dummy-schema" examples = ["example-1", "example-2"] + neo4j_database = "mydb" retriever = Text2CypherRetriever( - driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples + driver=driver, + llm=llm, + neo4j_schema=neo4j_schema, + examples=examples, + neo4j_database=neo4j_database, ) llm.invoke.return_value = LLMResponse(content=t2c_query) driver.execute_query.return_value = ( @@ -133,7 +138,9 @@ def test_t2c_retriever_happy_path( ) retriever.search(query_text=query_text) llm.invoke.assert_called_once_with(prompt) - driver.execute_query.assert_called_once_with(query_=t2c_query) + driver.execute_query.assert_called_once_with( + query_=t2c_query, database_=neo4j_database + ) @patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")