Skip to content

Commit 5270d23

Browse files
committed
Add node_label_neo4j parameter to Qdrant and Pinecone retrievers
1 parent b8f3938 commit 5270d23

File tree

10 files changed

+30
-4
lines changed

10 files changed

+30
-4
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
## Next
44

5+
### Added
6+
7+
- Added an optional `node_label_neo4j` parameter to the external retrievers; when provided, the node label is used in the Neo4j search query to speed it up.
8+
9+
510
## 1.10.1
611

712
### Added

src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
8383
retrieval_query (str): Cypher query that gets appended.
8484
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8585
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
86+
node_label_neo4j (Optional[str]): The label of the Neo4j nodes to retrieve. Optional; providing it speeds up the search query in Neo4j.
8687
8788
Raises:
8889
RetrieverInitializationError: If validation of the input arguments fails.
@@ -101,6 +102,7 @@ def __init__(
101102
Callable[[neo4j.Record], RetrieverResultItem]
102103
] = None,
103104
neo4j_database: Optional[str] = None,
105+
node_label_neo4j: Optional[str] = None,
104106
):
105107
try:
106108
driver_model = Neo4jDriverModel(driver=driver)
@@ -116,6 +118,7 @@ def __init__(
116118
retrieval_query=retrieval_query,
117119
result_formatter=result_formatter,
118120
neo4j_database=neo4j_database,
121+
node_label_neo4j=node_label_neo4j,
119122
)
120123
except ValidationError as e:
121124
raise RetrieverInitializationError(e.errors()) from e
@@ -125,6 +128,7 @@ def __init__(
125128
id_property_external="id",
126129
id_property_neo4j=validated_data.id_property_neo4j,
127130
neo4j_database=neo4j_database,
131+
node_label_neo4j=node_label_neo4j,
128132
)
129133
self.driver = validated_data.driver_model.driver
130134
self.client = validated_data.client_model.client
@@ -172,7 +176,8 @@ def get_search_results(
172176
driver=neo4j_driver,
173177
client=pc_client,
174178
index_name="jeopardy",
175-
id_property_neo4j="id"
179+
id_property_neo4j="id",
180+
node_label_neo4j="Document",
176181
)
177182
biology_embedding = ...
178183
retriever.search(query_vector=biology_embedding, top_k=2)
@@ -223,6 +228,7 @@ def get_search_results(
223228
search_query = get_match_query(
224229
return_properties=self.return_properties,
225230
retrieval_query=self.retrieval_query,
231+
node_label=self.node_label_neo4j,
226232
)
227233

228234
parameters = {

src/neo4j_graphrag/retrievers/external/pinecone/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@ class PineconeNeo4jRetrieverModel(BaseModel):
5959
retrieval_query: Optional[str] = None
6060
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
6161
neo4j_database: Optional[str] = None
62+
node_label_neo4j: Optional[str] = None

src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class QdrantNeo4jRetriever(ExternalRetriever):
7979
return_properties (Optional[list[str]]): List of node properties to return.
8080
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8181
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
82+
node_label_neo4j (Optional[str]): The label of the Neo4j nodes to retrieve. Optional; providing it speeds up the search query in Neo4j.
8283
8384
Raises:
8485
RetrieverInitializationError: If validation of the input arguments fails.
@@ -99,6 +100,7 @@ def __init__(
99100
Callable[[neo4j.Record], RetrieverResultItem]
100101
] = None,
101102
neo4j_database: Optional[str] = None,
103+
node_label_neo4j: Optional[str] = None,
102104
):
103105
try:
104106
driver_model = Neo4jDriverModel(driver=driver)
@@ -116,6 +118,7 @@ def __init__(
116118
retrieval_query=retrieval_query,
117119
result_formatter=result_formatter,
118120
neo4j_database=neo4j_database,
121+
node_label_neo4j=node_label_neo4j,
119122
)
120123
except ValidationError as e:
121124
raise RetrieverInitializationError(e.errors()) from e
@@ -125,6 +128,7 @@ def __init__(
125128
id_property_external=validated_data.id_property_external,
126129
id_property_neo4j=validated_data.id_property_neo4j,
127130
neo4j_database=neo4j_database,
131+
node_label_neo4j=node_label_neo4j,
128132
)
129133
self.driver = validated_data.driver_model.driver
130134
self.client = validated_data.client_model.client
@@ -169,7 +173,8 @@ def get_search_results(
169173
driver=neo4j_driver,
170174
client=client,
171175
collection_name="my_collection",
172-
id_property_external="neo4j_id"
176+
id_property_external="neo4j_id",
177+
node_label_neo4j="Document",
173178
)
174179
embedding = ...
175180
retriever.search(query_vector=embedding, top_k=2)
@@ -223,6 +228,7 @@ def get_search_results(
223228
search_query = get_match_query(
224229
return_properties=self.return_properties,
225230
retrieval_query=self.retrieval_query,
231+
node_label=self.node_label_neo4j,
226232
)
227233

228234
parameters = {

src/neo4j_graphrag/retrievers/external/qdrant/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ class QdrantNeo4jRetrieverModel(BaseModel):
5454
retrieval_query: Optional[str] = None
5555
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
5656
neo4j_database: Optional[str] = None
57+
node_label_neo4j: Optional[str] = None

src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
8181
return_properties (Optional[list[str]]): List of node properties to return.
8282
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8383
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
84+
node_label_neo4j (Optional[str]): The label of the Neo4j nodes to retrieve. Optional; providing it speeds up the search query in Neo4j.
8485
8586
Raises:
8687
RetrieverInitializationError: If validation of the input arguments fails.
@@ -170,6 +171,7 @@ def get_search_results(
170171
collection="Jeopardy",
171172
id_property_external="neo4j_id",
172173
id_property_neo4j="id",
174+
node_label_neo4j="Document",
173175
)
174176
175177
biology_embedding = ...

tests/e2e/pinecone_e2e/test_pinecone_e2e.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ def populate_neo4j_db(driver: MagicMock) -> None:
5252
@pytest.mark.usefixtures("populate_neo4j_db")
5353
def test_pinecone_neo4j_vector_input(driver: MagicMock, client: MagicMock) -> None:
5454
retriever = PineconeNeo4jRetriever(
55-
driver=driver, client=client, index_name="jeopardy", id_property_neo4j="id"
55+
driver=driver,
56+
client=client,
57+
index_name="jeopardy",
58+
id_property_neo4j="id",
59+
node_label_neo4j="Question",
5660
)
5761
with mock.patch.object(retriever, "index") as mock_index:
5862
top_k = 2

tests/e2e/qdrant_e2e/test_qdrant_e2e.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_qdrant_neo4j_vector_input(driver: Driver, qdrant_client: QdrantClient)
5959
collection_name="Jeopardy",
6060
id_property_external="neo4j_id",
6161
id_property_neo4j="id",
62+
node_label_neo4j="Question",
6263
)
6364

6465
top_k = 1

tests/e2e/weaviate_e2e/test_weaviate_e2e.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def test_weaviate_neo4j_vector_input(
6363
collection="Jeopardy",
6464
id_property_external="neo4j_id",
6565
id_property_neo4j="id",
66+
node_label_neo4j="Question",
6667
)
6768

6869
top_k = 2

tests/unit/retrievers/external/test_weaviate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def test_match_query_with_both_return_properties_and_retrieval_query() -> None:
261261

262262

263263
def test_match_query_with_custom_node_label() -> None:
264-
# Should ignore return_properties
265264
match_query = get_match_query(
266265
return_properties=["name", "age"], node_label="MyNodeLabel"
267266
)

0 commit comments

Comments
 (0)