Skip to content

Commit b8f3938

Browse files
committed
Add a node label to filter on for external retriever base class and WeaviateRetriever
1 parent c7653df commit b8f3938

File tree

5 files changed

+31
-3
lines changed

5 files changed

+31
-3
lines changed

src/neo4j_graphrag/retrievers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,12 @@ def __init__(
454454
id_property_external: str,
455455
id_property_neo4j: str,
456456
neo4j_database: Optional[str] = None,
457+
node_label_neo4j: Optional[str] = None,
457458
):
458459
super().__init__(driver)
459460
self.id_property_external = id_property_external
460461
self.id_property_neo4j = id_property_neo4j
462+
self.node_label_neo4j = node_label_neo4j
461463
self.neo4j_database = neo4j_database
462464

463465
@abstractmethod

src/neo4j_graphrag/retrievers/external/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020

2121

2222
def get_match_query(
23-
return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None
23+
return_properties: Optional[list[str]] = None,
24+
retrieval_query: Optional[str] = None,
25+
node_label: Optional[str] = None,
2426
) -> str:
27+
node_label_filter = f":`{node_label}`" if node_label else ""
2528
match_query = (
2629
"UNWIND $match_params AS match_param "
2730
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
28-
"MATCH (node) "
31+
f"MATCH (node{node_label_filter}) "
2932
"WHERE node[$id_property] = match_id_value "
3033
)
3134
return match_query + get_query_tail(

src/neo4j_graphrag/retrievers/external/weaviate/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
5757
retrieval_query: Optional[str] = None
5858
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
5959
neo4j_database: Optional[str] = None
60+
node_label_neo4j: Optional[str] = None
6061

6162

6263
class WeaviateNeo4jSearchModel(VectorSearchModel):

src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(
100100
Callable[[neo4j.Record], RetrieverResultItem]
101101
] = None,
102102
neo4j_database: Optional[str] = None,
103+
node_label_neo4j: Optional[str] = None,
103104
):
104105
try:
105106
driver_model = Neo4jDriverModel(driver=driver)
@@ -116,12 +117,17 @@ def __init__(
116117
retrieval_query=retrieval_query,
117118
result_formatter=result_formatter,
118119
neo4j_database=neo4j_database,
120+
node_label_neo4j=node_label_neo4j,
119121
)
120122
except ValidationError as e:
121123
raise RetrieverInitializationError(e.errors()) from e
122124

123125
super().__init__(
124-
driver, id_property_external, id_property_neo4j, neo4j_database
126+
driver,
127+
id_property_external,
128+
id_property_neo4j,
129+
neo4j_database,
130+
node_label_neo4j,
125131
)
126132
self.client = validated_data.client_model.client
127133
collection = validated_data.collection
@@ -234,6 +240,7 @@ def get_search_results(
234240
search_query = get_match_query(
235241
return_properties=self.return_properties,
236242
retrieval_query=self.retrieval_query,
243+
node_label=self.node_label_neo4j,
237244
)
238245

239246
parameters = {

tests/unit/retrievers/external/test_weaviate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,21 @@ def test_match_query_with_both_return_properties_and_retrieval_query() -> None:
260260
assert match_query.strip() == expected.strip()
261261

262262

263+
def test_match_query_with_custom_node_label() -> None:
264+
# Should add the node label to the MATCH clause and still apply return_properties
265+
match_query = get_match_query(
266+
return_properties=["name", "age"], node_label="MyNodeLabel"
267+
)
268+
expected = (
269+
"UNWIND $match_params AS match_param "
270+
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
271+
"MATCH (node:`MyNodeLabel`) "
272+
"WHERE node[$id_property] = match_id_value "
273+
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score "
274+
)
275+
assert match_query.strip() == expected.strip()
276+
277+
263278
def test_weaviate_retriever_with_result_format_function(
264279
driver: MagicMock, neo4j_record: MagicMock, result_formatter: MagicMock
265280
) -> None:

0 commit comments

Comments
 (0)