Skip to content

Commit 68b45d9

Browse files
committed
Add id property getter to qdrant retriever
1 parent 03ddb3c commit 68b45d9

File tree

3 files changed

+80
-5
lines changed

3 files changed

+80
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
### Added
66

77
- Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j.
8-
8+
- Added an optional `id_property_getter` callable parameter in the Qdrant retriever to allow for custom ID retrieval.
99

1010
## 1.10.1
1111

src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import neo4j
2121
from pydantic import ValidationError
2222
from qdrant_client import QdrantClient
23+
from qdrant_client.conversions.common_types import ScoredPoint
2324

2425
from neo4j_graphrag.embeddings.base import Embedder
2526
from neo4j_graphrag.exceptions import (
@@ -80,6 +81,7 @@ class QdrantNeo4jRetriever(ExternalRetriever):
8081
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8182
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
8283
node_label_neo4j (Optional[str]): The label of the Neo4j node to retrieve. This label must be properly escaped if needed, eg "`Label with spaces`".
84+
id_property_getter (Optional[Callable[[ScoredPoint], Any]]): Function to get the id property from a ScoredPoint. If not provided, defaults to point.payload.get(id_property_external, point.id).
8385
8486
Raises:
8587
RetrieverInitializationError: If validation of the input arguments fails.
@@ -101,6 +103,7 @@ def __init__(
101103
] = None,
102104
neo4j_database: Optional[str] = None,
103105
node_label_neo4j: Optional[str] = None,
106+
id_property_getter: Optional[Callable[[ScoredPoint], Any]] = None,
104107
):
105108
try:
106109
driver_model = Neo4jDriverModel(driver=driver)
@@ -142,6 +145,14 @@ def __init__(
142145
self.return_properties = validated_data.return_properties
143146
self.retrieval_query = validated_data.retrieval_query
144147
self.result_formatter = validated_data.result_formatter
148+
self.id_property_getter = id_property_getter
149+
150+
def get_match_id_from_point(self, point: ScoredPoint) -> Any:
    """Return the value used to match a Qdrant point against a Neo4j node.

    Args:
        point (ScoredPoint): A point returned by the Qdrant query.

    Returns:
        Any: The result of ``self.id_property_getter(point)`` when a custom
        getter was supplied. Otherwise the payload entry stored under
        ``self.id_property_external``, falling back to ``point.id`` when the
        payload does not contain that key.

    Raises:
        ValueError: If no custom getter is set and the point has no payload.
    """
    getter = self.id_property_getter
    # Test for None explicitly: a callable object may define __bool__ or
    # __len__ and evaluate as falsy, which a truthiness check would silently
    # treat as "no getter supplied".
    if getter is not None:
        return getter(point)
    if point.payload is None:
        raise ValueError(f"Payload is None for point {point}")
    return point.payload.get(self.id_property_external, point.id)
145156

146157
def get_search_results(
147158
self,
@@ -220,10 +231,7 @@ def get_search_results(
220231

221232
result_tuples = []
222233
for point in points:
223-
assert point.payload is not None
224-
result_tuples.append(
225-
[point.payload.get(self.id_property_external, point.id), point.score]
226-
)
234+
result_tuples.append([self.get_match_id_from_point(point), point.score])
227235

228236
search_query = get_match_query(
229237
return_properties=self.return_properties,

tests/unit/retrievers/external/test_qdrant.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,70 @@ def test_qdrant_retriever_invalid_retrieval_query(
267267

268268
assert "retrieval_query" in str(exc_info.value)
269269
assert "Input should be a valid string" in str(exc_info.value)
270+
271+
272+
def test_qdrant_retriever_search_custom_match_id_getter(
    driver: MagicMock, client: MagicMock
) -> None:
    """Check that a custom ``id_property_getter`` supplies the match IDs.

    Each mocked Qdrant point keeps its Neo4j ID nested under
    ``payload["data"]["id"]``; the getter passed to the retriever extracts it,
    and the test verifies that those values (paired with the point scores) are
    what reaches ``driver.execute_query`` and that the returned records are
    formatted as expected.
    """
    top_k = 5
    scores = [i / top_k for i in range(top_k)]
    node_ids = [f"node_{i}" for i in range(top_k)]

    def my_id_getter(point: ScoredPoint) -> Any:
        payload = point.payload
        if payload is None:
            raise Exception("Payload is None")
        return payload["data"]["id"]

    retriever = QdrantNeo4jRetriever(
        driver=driver,
        client=client,
        collection_name="dummy-text",
        id_property_neo4j="sync_id",
        id_property_getter=my_id_getter,
    )

    with mock.patch.object(retriever, "client") as mock_client:
        # Qdrant side: one scored point per index, ID nested inside the payload.
        scored_points = []
        for index, (node_id, score) in enumerate(zip(node_ids, scores)):
            scored_points.append(
                ScoredPoint(
                    id=index,
                    version=0,
                    score=score,
                    payload={"data": {"id": node_id}},
                )
            )
        mock_client.query_points.return_value = QueryResponse(points=scored_points)

        # Neo4j side: records echoing the same IDs and scores.
        neo4j_records = [
            neo4j.Record({"node": {"sync_id": node_id}, "score": score})
            for node_id, score in zip(node_ids, scores)
        ]
        driver.execute_query.return_value = (neo4j_records, None, None)

        query_vector = [1.0] * 1536
        search_query = get_match_query()
        records = retriever.search(query_vector=query_vector)

        expected_parameters = {
            "match_params": [
                [node_id, score] for node_id, score in zip(node_ids, scores)
            ],
            "id_property": "sync_id",
        }
        driver.execute_query.assert_called_once_with(
            search_query,
            expected_parameters,
            database_=None,
            routing_=neo4j.RoutingControl.READ,
        )

        expected_items = [
            RetrieverResultItem(
                content=f"<Record node={{'sync_id': '{node_id}'}} score={score}>",
                metadata=None,
            )
            for node_id, score in zip(node_ids, scores)
        ]
        assert records == RetrieverResult(
            items=expected_items,
            metadata={"__retriever": "QdrantNeo4jRetriever"},
        )

0 commit comments

Comments
 (0)