Skip to content

Commit d478207

Browse files
committed
added metadata filter support to graph methods
1 parent 5b0bca4 commit d478207

File tree

3 files changed

+368
-79
lines changed

3 files changed

+368
-79
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 148 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
cast,
1919
)
2020

21-
from cassandra.cluster import ConsistencyLevel, Session
21+
from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session
2222
from cassio.config import check_resolve_keyspace, check_resolve_session
2323

2424
from ._mmr_helper import MmrHelper
@@ -32,7 +32,7 @@
3232
CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob"
3333

3434
SELECT_CQL_TEMPLATE = (
35-
"SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
35+
"SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};"
3636
)
3737

3838

@@ -172,6 +172,7 @@ def __init__(
172172
self._node_table = node_table
173173
self._session = session
174174
self._keyspace = keyspace
175+
self._prepared_query_cache: Dict[str, PreparedStatement] = {}
175176

176177
self._metadata_indexing_policy = self._normalize_metadata_indexing_policy(
177178
metadata_indexing=metadata_indexing,
@@ -219,28 +220,6 @@ def __init__(
219220
""" # noqa: S608
220221
)
221222

222-
self._query_targets_embeddings_by_kind_and_tag_and_embedding = session.prepare(
223-
f"""
224-
SELECT
225-
content_id AS target_content_id,
226-
text_embedding AS target_text_embedding,
227-
link_to_tags AS target_link_to_tags
228-
FROM {keyspace}.{node_table}
229-
WHERE link_from_tags CONTAINS (?, ?)
230-
ORDER BY text_embedding ANN of ?
231-
LIMIT ?
232-
"""
233-
)
234-
235-
self._query_targets_by_kind_and_value = session.prepare(
236-
f"""
237-
SELECT
238-
content_id AS target_content_id
239-
FROM {keyspace}.{node_table}
240-
WHERE link_from_tags CONTAINS (?, ?)
241-
"""
242-
)
243-
244223
def table_name(self) -> str:
245224
"""Returns the fully qualified table name."""
246225
return f"{self._keyspace}.{self._node_table}"
@@ -427,15 +406,23 @@ def mmr_traversal_search(
427406

428407
# Fetch the initial candidates and add them to the helper and
429408
# outgoing_tags.
409+
initial_candidates_query = self._get_search_cql(
410+
has_limit=True,
411+
columns="content_id, text_embedding, link_to_tags",
412+
metadata_keys=list(metadata_filter.keys()),
413+
has_embedding=True,
414+
)
415+
430416
def fetch_initial_candidates() -> None:
431-
query, params = self._get_search_cql(
417+
params = self._get_search_params(
432418
limit=fetch_k,
433-
columns="content_id, text_embedding, link_to_tags",
434419
metadata=metadata_filter,
435420
embedding=query_embedding,
436421
)
437422

438-
fetched = self._session.execute(query=query, parameters=params)
423+
fetched = self._session.execute(
424+
query=initial_candidates_query, parameters=params
425+
)
439426
candidates = {}
440427
for row in fetched:
441428
candidates[row.content_id] = row.text_embedding
@@ -474,6 +461,7 @@ def fetch_initial_candidates() -> None:
474461
link_to_tags,
475462
query_embedding=query_embedding,
476463
k_per_tag=adjacent_k,
464+
metadata_filter=metadata_filter,
477465
)
478466

479467
# Record the link_to_tags as visited.
@@ -541,6 +529,19 @@ def traversal_search(
541529
#
542530
# ...
543531

532+
traversal_query = self._get_search_cql(
533+
columns="content_id, link_to_tags",
534+
has_limit=True,
535+
metadata_keys=list(metadata_filter.keys()),
536+
has_embedding=True,
537+
)
538+
539+
visit_nodes_query = self._get_search_cql(
540+
columns="content_id AS target_content_id",
541+
has_link_from_tags=True,
542+
metadata_keys=list(metadata_filter.keys()),
543+
)
544+
544545
with self._concurrent_queries() as cq:
545546
# Map from visited ID to depth
546547
visited_ids: Dict[str, int] = {}
@@ -583,12 +584,12 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None:
583584
# If there are new tags to visit at the next depth, query for the
584585
# node IDs.
585586
for kind, value in outgoing_tags:
587+
params = self._get_search_params(
588+
link_from_tags=(kind, value), metadata=metadata_filter
589+
)
586590
cq.execute(
587-
self._query_targets_by_kind_and_value,
588-
parameters=(
589-
kind,
590-
value,
591-
),
591+
query=visit_nodes_query,
592+
parameters=params,
592593
callback=lambda rows, d=d: visit_targets(d, rows),
593594
)
594595

@@ -611,15 +612,14 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None:
611612
)
612613

613614
query_embedding = self._embedding.embed_query(query)
614-
query, params = self._get_search_cql(
615-
columns="content_id, link_to_tags",
615+
params = self._get_search_params(
616616
limit=k,
617617
metadata=metadata_filter,
618618
embedding=query_embedding,
619619
)
620620

621621
cq.execute(
622-
query,
622+
traversal_query,
623623
parameters=params,
624624
callback=lambda nodes: visit_nodes(0, nodes),
625625
)
@@ -633,7 +633,7 @@ def similarity_search(
633633
metadata_filter: Dict[str, Any] = {},
634634
) -> Iterable[Node]:
635635
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
636-
query, params = self._get_search_cql(
636+
query, params = self._get_search_cql_and_params(
637637
embedding=embedding, limit=k, metadata=metadata_filter
638638
)
639639

@@ -644,7 +644,7 @@ def metadata_search(
644644
self, metadata: Dict[str, Any] = {}, n: int = 5
645645
) -> Iterable[Node]:
646646
"""Retrieve nodes based on their metadata."""
647-
query, params = self._get_search_cql(metadata=metadata, limit=n)
647+
query, params = self._get_search_cql_and_params(metadata=metadata, limit=n)
648648

649649
for row in self._session.execute(query, params):
650650
yield _row_to_node(row)
@@ -681,19 +681,35 @@ def _get_adjacent(
681681
tags: Set[Tuple[str, str]],
682682
query_embedding: List[float],
683683
k_per_tag: Optional[int] = None,
684+
metadata_filter: Dict[str, Any] = {},
684685
) -> Iterable[_Edge]:
685686
"""Return the target nodes with incoming links from any of the given tags.
686687
687688
Args:
688689
tags: The tags to look for links *from*.
689690
query_embedding: The query embedding. Used to rank target nodes.
690691
k_per_tag: The number of target nodes to fetch for each outgoing tag.
692+
metadata_filter: Optional metadata to filter the results.
691693
692694
Returns:
693695
List of adjacent edges.
694696
"""
695697
targets: Dict[str, _Edge] = {}
696698

699+
columns = """
700+
content_id AS target_content_id,
701+
text_embedding AS target_text_embedding,
702+
link_to_tags AS target_link_to_tags
703+
"""
704+
705+
adjacent_query = self._get_search_cql(
706+
has_limit=True,
707+
columns=columns,
708+
metadata_keys=list(metadata_filter.keys()),
709+
has_embedding=True,
710+
has_link_from_tags=True,
711+
)
712+
697713
def add_targets(rows: Iterable[Any]) -> None:
698714
# TODO: Figure out how to use the "kind" on the edge.
699715
# This is tricky, since we currently issue one query for anything
@@ -709,14 +725,16 @@ def add_targets(rows: Iterable[Any]) -> None:
709725

710726
with self._concurrent_queries() as cq:
711727
for kind, value in tags:
728+
params = self._get_search_params(
729+
limit=k_per_tag or 10,
730+
metadata=metadata_filter,
731+
embedding=query_embedding,
732+
link_from_tags=(kind, value),
733+
)
734+
712735
cq.execute(
713-
self._query_targets_embeddings_by_kind_and_tag_and_embedding,
714-
parameters=(
715-
kind,
716-
value,
717-
query_embedding,
718-
k_per_tag or 10,
719-
),
736+
query=adjacent_query,
737+
parameters=params,
720738
callback=add_targets,
721739
)
722740

@@ -784,55 +802,116 @@ def _coerce_string(value: Any) -> str:
784802
# when all else fails ...
785803
return str(value)
786804

787-
def _extract_where_clause_blocks(
788-
self, metadata: Dict[str, Any]
789-
) -> Tuple[str, List[Any]]:
805+
def _extract_where_clause_cql(
806+
self,
807+
metadata_keys: List[str] = [],
808+
has_link_from_tags: bool = False,
809+
) -> str:
790810
wc_blocks: List[str] = []
791-
vals_list: List[Any] = []
792811

793-
for key, value in sorted(metadata.items()):
812+
if has_link_from_tags:
813+
wc_blocks.append("link_from_tags CONTAINS (?, ?)")
814+
815+
for key in sorted(metadata_keys):
794816
if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
795817
wc_blocks.append(f"metadata_s['{key}'] = ?")
796-
vals_list.append(self._coerce_string(value=value))
797818
else:
798819
raise ValueError(
799820
"Non-indexed metadata fields cannot be used in queries."
800821
)
801822

802823
if len(wc_blocks) == 0:
803-
return "", []
824+
return ""
804825

805-
where_clause = "WHERE " + " AND ".join(wc_blocks)
806-
return where_clause, vals_list
826+
return " WHERE " + " AND ".join(wc_blocks)
827+
828+
def _extract_where_clause_params(
829+
self,
830+
metadata: Dict[str, Any],
831+
link_from_tags: Optional[Tuple[str, str]] = None,
832+
) -> List[Any]:
833+
params: List[Any] = []
834+
835+
if link_from_tags is not None:
836+
params.append(link_from_tags[0])
837+
params.append(link_from_tags[1])
838+
839+
for key, value in sorted(metadata.items()):
840+
if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
841+
params.append(self._coerce_string(value=value))
842+
else:
843+
raise ValueError(
844+
"Non-indexed metadata fields cannot be used in queries."
845+
)
846+
847+
return params
807848

808849
def _get_search_cql(
809850
self,
810-
limit: int,
851+
has_limit: bool = False,
811852
columns: Optional[str] = CONTENT_COLUMNS,
812-
metadata: Dict[str, Any] = {},
813-
embedding: Optional[List[float]] = None,
814-
) -> Tuple[str, Tuple[Any, ...]]:
815-
where_clause, get_cql_vals = self._extract_where_clause_blocks(
816-
metadata=metadata
853+
metadata_keys: List[str] = [],
854+
has_embedding: bool = False,
855+
has_link_from_tags: bool = False,
856+
) -> PreparedStatement:
857+
where_clause = self._extract_where_clause_cql(
858+
metadata_keys=metadata_keys, has_link_from_tags=has_link_from_tags
817859
)
818-
limit_clause = "LIMIT ?"
819-
limit_cql_vals = [limit]
860+
limit_clause = " LIMIT ?" if has_limit else ""
861+
order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else ""
820862

821-
order_clause = ""
822-
order_cql_vals = []
823-
if embedding is not None:
824-
order_clause = "ORDER BY text_embedding ANN OF ?"
825-
order_cql_vals = [embedding]
826-
827-
select_vals = tuple(list(get_cql_vals) + order_cql_vals + limit_cql_vals)
828863
select_cql = SELECT_CQL_TEMPLATE.format(
829864
columns=columns,
830865
table_name=self.table_name(),
831866
where_clause=where_clause,
832867
order_clause=order_clause,
833868
limit_clause=limit_clause,
834869
)
870+
871+
if select_cql in self._prepared_query_cache:
872+
return self._prepared_query_cache[select_cql]
873+
835874
prepared_query = self._session.prepare(select_cql)
836875
prepared_query.consistency_level = ConsistencyLevel.ONE
876+
self._prepared_query_cache[select_cql] = prepared_query
837877

838-
return prepared_query, select_vals
878+
return prepared_query
879+
880+
def _get_search_params(
    self,
    limit: Optional[int] = None,
    metadata: Dict[str, Any] = {},
    embedding: Optional[List[float]] = None,
    link_from_tags: Optional[Tuple[str, str]] = None,
) -> Tuple[Any, ...]:
    """Assemble the bind values for a statement built by `_get_search_cql`.

    The values follow the clause order of the SELECT template: WHERE values
    first (as produced by `_extract_where_clause_params`), then the embedding
    for the ORDER BY clause, then the LIMIT value. An embedding or limit of
    `None` contributes no value, so the flags passed to `_get_search_cql`
    must agree with which arguments are supplied here.

    Args:
        limit: Row limit to bind, or `None` when the statement has no LIMIT.
        metadata: Metadata key/value pairs to filter on. The default dict is
            only read here, never mutated.
        embedding: Query embedding to bind for the ANN ordering, or `None`.
        link_from_tags: Optional two-element tag to match.

    Returns:
        A flat tuple of bind values. This is only the parameters, not a
        (statement, parameters) pair.

    Raises:
        ValueError: Propagated from `_extract_where_clause_params` when a
            metadata key is not indexed.
    """
    where_params = self._extract_where_clause_params(
        metadata=metadata, link_from_tags=link_from_tags
    )
    order_params = [embedding] if embedding is not None else []
    limit_params = [limit] if limit is not None else []

    return (*where_params, *order_params, *limit_params)
895+
896+
def _get_search_cql_and_params(
    self,
    limit: Optional[int] = None,
    columns: Optional[str] = CONTENT_COLUMNS,
    metadata: Dict[str, Any] = {},
    embedding: Optional[List[float]] = None,
    link_from_tags: Optional[Tuple[str, str]] = None,
) -> Tuple[PreparedStatement, Tuple[Any, ...]]:
    """Build a search statement and its bind values from one set of arguments.

    The statement flags are derived from the arguments: a LIMIT clause is
    requested when `limit` is not `None`, ANN ordering when `embedding` is
    not `None`, and a link tag condition when `link_from_tags` is not
    `None`. Deriving both halves from the same arguments keeps placeholders
    and values in agreement.

    Args:
        limit: Optional row limit.
        columns: Column list to select.
        metadata: Metadata key/value pairs to filter on.
        embedding: Optional query embedding used for ordering.
        link_from_tags: Optional two-element tag to match.

    Returns:
        A `(statement, parameters)` pair: the result of `_get_search_cql`
        and the result of `_get_search_params`.
    """
    statement = self._get_search_cql(
        has_limit=limit is not None,
        columns=columns,
        metadata_keys=list(metadata),
        has_embedding=embedding is not None,
        has_link_from_tags=link_from_tags is not None,
    )
    bind_values = self._get_search_params(
        limit=limit,
        metadata=metadata,
        embedding=embedding,
        link_from_tags=link_from_tags,
    )
    return statement, bind_values

0 commit comments

Comments
 (0)