Skip to content

Commit 3a66048

Browse files
committed
added metadata filtering to existing search methods
1 parent d764289 commit 3a66048

File tree

1 file changed

+52
-59
lines changed

1 file changed

+52
-59
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 52 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
CONTENT_COLUMNS = "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
3131

32-
SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {table_name} {where_clause} {limit_clause};"
32+
SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
3333

3434
@dataclass
3535
class Node:
@@ -198,28 +198,6 @@ def __init__(
198198
""" # noqa: S608
199199
)
200200

201-
self._query_by_embedding = session.prepare(
202-
f"""
203-
SELECT {CONTENT_COLUMNS}
204-
FROM {keyspace}.{node_table}
205-
ORDER BY text_embedding ANN OF ?
206-
LIMIT ?
207-
""" # noqa: S608
208-
)
209-
self._query_by_embedding.consistency_level = ConsistencyLevel.ONE
210-
211-
self._query_ids_and_link_to_tags_by_embedding = session.prepare(
212-
f"""
213-
SELECT content_id, link_to_tags
214-
FROM {keyspace}.{node_table}
215-
ORDER BY text_embedding ANN OF ?
216-
LIMIT ?
217-
""" # noqa: S608
218-
)
219-
self._query_ids_and_link_to_tags_by_embedding.consistency_level = (
220-
ConsistencyLevel.ONE
221-
)
222-
223201
self._query_ids_and_link_to_tags_by_id = session.prepare(
224202
f"""
225203
SELECT content_id, link_to_tags
@@ -228,18 +206,6 @@ def __init__(
228206
""" # noqa: S608
229207
)
230208

231-
self._query_ids_and_embedding_by_embedding = session.prepare(
232-
f"""
233-
SELECT content_id, text_embedding, link_to_tags
234-
FROM {keyspace}.{node_table}
235-
ORDER BY text_embedding ANN OF ?
236-
LIMIT ?
237-
""" # noqa: S608
238-
)
239-
self._query_ids_and_embedding_by_embedding.consistency_level = (
240-
ConsistencyLevel.ONE
241-
)
242-
243209
self._query_source_tags_by_id = session.prepare(
244210
f"""
245211
SELECT link_to_tags
@@ -270,11 +236,14 @@ def __init__(
270236
"""
271237
)
272238

239+
def table_name(self) -> str:
    """Return the node table name qualified by its keyspace.

    Returns:
        A string of the form ``<keyspace>.<node_table>``, built from
        ``self._keyspace`` and ``self._node_table``, suitable for
        interpolating into a CQL statement.
    """
    return f"{self._keyspace}.{self._node_table}"
241+
273242
def _apply_schema(self) -> None:
274243
"""Apply the schema to the database."""
275244
embedding_dim = len(self._embedding.embed_query("Test Query"))
276245
self._session.execute(f"""
277-
CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._node_table} (
246+
CREATE TABLE IF NOT EXISTS {self.table_name()} (
278247
content_id TEXT,
279248
kind TEXT,
280249
text_content TEXT,
@@ -293,19 +262,19 @@ def _apply_schema(self) -> None:
293262
# Index on text_embedding (for similarity search)
294263
self._session.execute(f"""
295264
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index
296-
ON {self._keyspace}.{self._node_table}(text_embedding)
265+
ON {self.table_name()}(text_embedding)
297266
USING 'StorageAttachedIndex';
298267
""")
299268

300269
self._session.execute(f"""
301270
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_from_tags
302-
ON {self._keyspace}.{self._node_table}(link_from_tags)
271+
ON {self.table_name()}(link_from_tags)
303272
USING 'StorageAttachedIndex';
304273
""")
305274

306275
self._session.execute(f"""
307276
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_index
308-
ON {self._keyspace}.{self._node_table}(ENTRIES(metadata_s))
277+
ON {self.table_name()}(ENTRIES(metadata_s))
309278
USING 'StorageAttachedIndex';
310279
""")
311280

@@ -425,6 +394,7 @@ def mmr_traversal_search(
425394
adjacent_k: int = 10,
426395
lambda_mult: float = 0.5,
427396
score_threshold: float = float("-inf"),
397+
metadata: Optional[Dict[str, Any]] = [],
428398
) -> Iterable[Node]:
429399
"""Retrieve documents from this graph store using MMR-traversal.
430400
@@ -450,6 +420,7 @@ def mmr_traversal_search(
450420
diversity and 1 to minimum diversity. Defaults to 0.5.
451421
score_threshold: Only documents with a score greater than or equal
452422
to this threshold will be chosen. Defaults to -infinity.
423+
metadata: Optional metadata to filter the results.
453424
"""
454425
query_embedding = self._embedding.embed_query(query)
455426
helper = MmrHelper(
@@ -465,10 +436,14 @@ def mmr_traversal_search(
465436
# Fetch the initial candidates and add them to the helper and
466437
# outgoing_tags.
467438
def fetch_initial_candidates() -> None:
468-
fetched = self._session.execute(
469-
self._query_ids_and_embedding_by_embedding,
470-
(query_embedding, fetch_k),
439+
query, params = self._get_search_cql(
440+
limit=fetch_k,
441+
columns="content_id, text_embedding, link_to_tags",
442+
metadata=metadata,
443+
embedding=query_embedding
471444
)
445+
446+
fetched = self._session.execute(query=query, parameters=params)
472447
candidates = {}
473448
for row in fetched:
474449
candidates[row.content_id] = row.text_embedding
@@ -540,7 +515,7 @@ def fetch_initial_candidates() -> None:
540515
return self._nodes_with_ids(helper.selected_ids)
541516

542517
def traversal_search(
543-
self, query: str, *, k: int = 4, depth: int = 1
518+
self, query: str, *, k: int = 4, depth: int = 1, metadata: Optional[Dict[str, Any]] = [],
544519
) -> Iterable[Node]:
545520
"""Retrieve documents from this knowledge store.
546521
@@ -553,6 +528,7 @@ def traversal_search(
553528
k: The number of Documents to return from the initial vector search.
554529
Defaults to 4.
555530
depth: The maximum depth of edges to traverse. Defaults to 1.
531+
metadata: Optional metadata to filter the results.
556532
557533
Returns:
558534
Collection of retrieved documents.
@@ -638,9 +614,15 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None:
638614
)
639615

640616
query_embedding = self._embedding.embed_query(query)
617+
query, params = self._get_search_cql(
618+
limit=k,
619+
metadata=metadata,
620+
embedding=query_embedding,
621+
)
622+
641623
cq.execute(
642-
self._query_ids_and_link_to_tags_by_embedding,
643-
parameters=(query_embedding, k),
624+
query,
625+
parameters=params,
644626
callback=lambda nodes: visit_nodes(0, nodes),
645627
)
646628

@@ -650,17 +632,18 @@ def similarity_search(
650632
self,
651633
embedding: List[float],
652634
k: int = 4,
635+
metadata: Optional[Dict[str, Any]] = [],
653636
) -> Iterable[Node]:
654-
"""Retrieve nodes similar to the given embedding."""
655-
for row in self._session.execute(self._query_by_embedding, (embedding, k)):
637+
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata"""
638+
query, params = self._get_search_cql(embedding=embedding, limit=k, metadata=metadata)
639+
640+
for row in self._session.execute(query, params):
656641
yield _row_to_node(row)
657642

658643
def metadata_search(
    self, metadata: Optional[Dict[str, Any]] = None, n: Optional[int] = 5
) -> Iterable[Node]:
    """Yield the nodes selected by a metadata-filtered query.

    Args:
        metadata: Key/value pairs handed to ``_get_search_cql`` to build the
            WHERE clause. ``None`` (the default) is treated as an empty dict.
        n: Value bound to the query's ``LIMIT ?`` placeholder. Defaults to 5.

    Yields:
        One ``Node`` per returned row, converted with ``_row_to_node``.
    """
    # A literal ``{}`` default is a single dict object shared by every call;
    # use a None sentinel and create a fresh dict per call instead.
    if metadata is None:
        metadata = {}

    query, params = self._get_search_cql(metadata=metadata, limit=n)

    for row in self._session.execute(query, params):
        yield _row_to_node(row)
665648

666649
def get_node(self, id: str) -> Node:
@@ -802,7 +785,7 @@ def _extract_where_clause_blocks(
802785
self, metadata: Dict[str, Any]
803786
) -> Tuple[str, List[Any]]:
804787

805-
attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)
788+
_, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)
806789

807790
if len(metadata_s) == 0:
808791
return "", []
@@ -818,17 +801,27 @@ def _extract_where_clause_blocks(
818801
return where_clause, vals_list
819802

820803

821-
def _get_search_cql(
    self,
    limit: int,
    columns: Optional[str] = CONTENT_COLUMNS,
    metadata: Optional[Dict[str, Any]] = None,
    embedding: Optional[List[float]] = None,
) -> Tuple[Any, Tuple[Any, ...]]:
    """Build and prepare the SELECT statement used by the search methods.

    Args:
        limit: Value bound to the trailing ``LIMIT ?`` placeholder.
        columns: Comma-separated column list to select. ``None`` falls back
            to ``CONTENT_COLUMNS``.
        metadata: Metadata used to build the WHERE clause. Any empty value
            (``None``, ``{}``, or the ``[]`` some callers use as their
            default) is forwarded as a fresh empty dict.
        embedding: When given, ``ORDER BY text_embedding ANN OF ?`` is added
            and this vector is bound to it. When ``None`` no ORDER BY
            clause is emitted.

    Returns:
        A ``(prepared_statement, parameters)`` pair. Parameters follow the
        placeholder order of the statement: WHERE values, then the embedding
        (if any), then the limit.
    """
    if columns is None:
        columns = CONTENT_COLUMNS

    # Never forward None/[] to the where-clause builder, and never share one
    # default dict object between calls.
    where_clause, where_vals = self._extract_where_clause_blocks(
        metadata=metadata or {}
    )

    if embedding is not None:
        order_clause = "ORDER BY text_embedding ANN OF ?"
        order_vals = [embedding]
    else:
        order_clause = ""
        order_vals = []

    select_cql = SELECT_CQL_TEMPLATE.format(
        columns=columns,
        table_name=self.table_name(),
        where_clause=where_clause,
        order_clause=order_clause,
        limit_clause="LIMIT ?",
    )
    # Bind values must appear in the same order as their placeholders.
    select_vals = tuple(list(where_vals) + order_vals + [limit])

    # NOTE(review): the statement is prepared on every call; consider caching
    # prepared statements keyed by the CQL text if this shows up as overhead.
    prepared_query = self._session.prepare(select_cql)
    prepared_query.consistency_level = ConsistencyLevel.ONE

    return prepared_query, select_vals

0 commit comments

Comments
 (0)