18
18
cast ,
19
19
)
20
20
21
- from cassandra .cluster import ConsistencyLevel , Session
21
+ from cassandra .cluster import ConsistencyLevel , PreparedStatement , Session
22
22
from cassio .config import check_resolve_keyspace , check_resolve_session
23
23
24
24
from ._mmr_helper import MmrHelper
32
32
CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob"
33
33
34
34
SELECT_CQL_TEMPLATE = (
35
- "SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
35
+ "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};"
36
36
)
37
37
38
38
@@ -172,6 +172,7 @@ def __init__(
172
172
self ._node_table = node_table
173
173
self ._session = session
174
174
self ._keyspace = keyspace
175
+ self ._prepared_query_cache : Dict [str , PreparedStatement ] = {}
175
176
176
177
self ._metadata_indexing_policy = self ._normalize_metadata_indexing_policy (
177
178
metadata_indexing = metadata_indexing ,
@@ -219,28 +220,6 @@ def __init__(
219
220
""" # noqa: S608
220
221
)
221
222
222
- self ._query_targets_embeddings_by_kind_and_tag_and_embedding = session .prepare (
223
- f"""
224
- SELECT
225
- content_id AS target_content_id,
226
- text_embedding AS target_text_embedding,
227
- link_to_tags AS target_link_to_tags
228
- FROM { keyspace } .{ node_table }
229
- WHERE link_from_tags CONTAINS (?, ?)
230
- ORDER BY text_embedding ANN of ?
231
- LIMIT ?
232
- """
233
- )
234
-
235
- self ._query_targets_by_kind_and_value = session .prepare (
236
- f"""
237
- SELECT
238
- content_id AS target_content_id
239
- FROM { keyspace } .{ node_table }
240
- WHERE link_from_tags CONTAINS (?, ?)
241
- """
242
- )
243
-
244
223
def table_name (self ) -> str :
245
224
"""Returns the fully qualified table name."""
246
225
return f"{ self ._keyspace } .{ self ._node_table } "
@@ -427,15 +406,23 @@ def mmr_traversal_search(
427
406
428
407
# Fetch the initial candidates and add them to the helper and
429
408
# outgoing_tags.
409
+ initial_candidates_query = self ._get_search_cql (
410
+ has_limit = True ,
411
+ columns = "content_id, text_embedding, link_to_tags" ,
412
+ metadata_keys = list (metadata_filter .keys ()),
413
+ has_embedding = True ,
414
+ )
415
+
430
416
def fetch_initial_candidates () -> None :
431
- query , params = self ._get_search_cql (
417
+ params = self ._get_search_params (
432
418
limit = fetch_k ,
433
- columns = "content_id, text_embedding, link_to_tags" ,
434
419
metadata = metadata_filter ,
435
420
embedding = query_embedding ,
436
421
)
437
422
438
- fetched = self ._session .execute (query = query , parameters = params )
423
+ fetched = self ._session .execute (
424
+ query = initial_candidates_query , parameters = params
425
+ )
439
426
candidates = {}
440
427
for row in fetched :
441
428
candidates [row .content_id ] = row .text_embedding
@@ -474,6 +461,7 @@ def fetch_initial_candidates() -> None:
474
461
link_to_tags ,
475
462
query_embedding = query_embedding ,
476
463
k_per_tag = adjacent_k ,
464
+ metadata_filter = metadata_filter ,
477
465
)
478
466
479
467
# Record the link_to_tags as visited.
@@ -541,6 +529,19 @@ def traversal_search(
541
529
#
542
530
# ...
543
531
532
+ traversal_query = self ._get_search_cql (
533
+ columns = "content_id, link_to_tags" ,
534
+ has_limit = True ,
535
+ metadata_keys = list (metadata_filter .keys ()),
536
+ has_embedding = True ,
537
+ )
538
+
539
+ visit_nodes_query = self ._get_search_cql (
540
+ columns = "content_id AS target_content_id" ,
541
+ has_link_from_tags = True ,
542
+ metadata_keys = list (metadata_filter .keys ()),
543
+ )
544
+
544
545
with self ._concurrent_queries () as cq :
545
546
# Map from visited ID to depth
546
547
visited_ids : Dict [str , int ] = {}
@@ -583,12 +584,12 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None:
583
584
# If there are new tags to visit at the next depth, query for the
584
585
# node IDs.
585
586
for kind , value in outgoing_tags :
587
+ params = self ._get_search_params (
588
+ link_from_tags = (kind , value ), metadata = metadata_filter
589
+ )
586
590
cq .execute (
587
- self ._query_targets_by_kind_and_value ,
588
- parameters = (
589
- kind ,
590
- value ,
591
- ),
591
+ query = visit_nodes_query ,
592
+ parameters = params ,
592
593
callback = lambda rows , d = d : visit_targets (d , rows ),
593
594
)
594
595
@@ -611,15 +612,14 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None:
611
612
)
612
613
613
614
query_embedding = self ._embedding .embed_query (query )
614
- query , params = self ._get_search_cql (
615
- columns = "content_id, link_to_tags" ,
615
+ params = self ._get_search_params (
616
616
limit = k ,
617
617
metadata = metadata_filter ,
618
618
embedding = query_embedding ,
619
619
)
620
620
621
621
cq .execute (
622
- query ,
622
+ traversal_query ,
623
623
parameters = params ,
624
624
callback = lambda nodes : visit_nodes (0 , nodes ),
625
625
)
@@ -633,7 +633,7 @@ def similarity_search(
633
633
metadata_filter : Dict [str , Any ] = {},
634
634
) -> Iterable [Node ]:
635
635
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
636
- query , params = self ._get_search_cql (
636
+ query , params = self ._get_search_cql_and_params (
637
637
embedding = embedding , limit = k , metadata = metadata_filter
638
638
)
639
639
@@ -644,7 +644,7 @@ def metadata_search(
644
644
self , metadata : Dict [str , Any ] = {}, n : int = 5
645
645
) -> Iterable [Node ]:
646
646
"""Retrieve nodes based on their metadata."""
647
- query , params = self ._get_search_cql (metadata = metadata , limit = n )
647
+ query , params = self ._get_search_cql_and_params (metadata = metadata , limit = n )
648
648
649
649
for row in self ._session .execute (query , params ):
650
650
yield _row_to_node (row )
@@ -681,19 +681,35 @@ def _get_adjacent(
681
681
tags : Set [Tuple [str , str ]],
682
682
query_embedding : List [float ],
683
683
k_per_tag : Optional [int ] = None ,
684
+ metadata_filter : Dict [str , Any ] = {},
684
685
) -> Iterable [_Edge ]:
685
686
"""Return the target nodes with incoming links from any of the given tags.
686
687
687
688
Args:
688
689
tags: The tags to look for links *from*.
689
690
query_embedding: The query embedding. Used to rank target nodes.
690
691
k_per_tag: The number of target nodes to fetch for each outgoing tag.
692
+ metadata_filter: Optional metadata to filter the results.
691
693
692
694
Returns:
693
695
List of adjacent edges.
694
696
"""
695
697
targets : Dict [str , _Edge ] = {}
696
698
699
+ columns = """
700
+ content_id AS target_content_id,
701
+ text_embedding AS target_text_embedding,
702
+ link_to_tags AS target_link_to_tags
703
+ """
704
+
705
+ adjacent_query = self ._get_search_cql (
706
+ has_limit = True ,
707
+ columns = columns ,
708
+ metadata_keys = list (metadata_filter .keys ()),
709
+ has_embedding = True ,
710
+ has_link_from_tags = True ,
711
+ )
712
+
697
713
def add_targets (rows : Iterable [Any ]) -> None :
698
714
# TODO: Figure out how to use the "kind" on the edge.
699
715
# This is tricky, since we currently issue one query for anything
@@ -709,14 +725,16 @@ def add_targets(rows: Iterable[Any]) -> None:
709
725
710
726
with self ._concurrent_queries () as cq :
711
727
for kind , value in tags :
728
+ params = self ._get_search_params (
729
+ limit = k_per_tag or 10 ,
730
+ metadata = metadata_filter ,
731
+ embedding = query_embedding ,
732
+ link_from_tags = (kind , value ),
733
+ )
734
+
712
735
cq .execute (
713
- self ._query_targets_embeddings_by_kind_and_tag_and_embedding ,
714
- parameters = (
715
- kind ,
716
- value ,
717
- query_embedding ,
718
- k_per_tag or 10 ,
719
- ),
736
+ query = adjacent_query ,
737
+ parameters = params ,
720
738
callback = add_targets ,
721
739
)
722
740
@@ -784,55 +802,116 @@ def _coerce_string(value: Any) -> str:
784
802
# when all else fails ...
785
803
return str (value )
786
804
def _extract_where_clause_cql(
    self,
    metadata_keys: Optional[List[str]] = None,
    has_link_from_tags: bool = False,
) -> str:
    """Build the WHERE clause text (with ``?`` placeholders) for a search.

    The placeholders are emitted in a fixed order: first the two
    placeholders of the ``link_from_tags CONTAINS (?, ?)`` condition (when
    requested), then one placeholder per metadata key, in sorted key order.

    Args:
        metadata_keys: Names of the metadata fields to filter on. ``None``
            or an empty list means no metadata conditions.
        has_link_from_tags: Whether to add the ``link_from_tags`` condition.

    Returns:
        An empty string when there are no conditions, otherwise the clause
        with a leading space (``" WHERE ... AND ..."``).

    Raises:
        ValueError: If any key is not indexed according to
            ``self._metadata_indexing_policy``.
    """
    conditions: List[str] = []

    if has_link_from_tags:
        conditions.append("link_from_tags CONTAINS (?, ?)")

    for key in sorted(metadata_keys or ()):
        if not _is_metadata_field_indexed(key, self._metadata_indexing_policy):
            raise ValueError(
                "Non-indexed metadata fields cannot be used in queries."
            )
        # The key is embedded in the statement as a CQL string literal, so
        # a single quote inside it must be doubled to keep the literal
        # well-formed (and to prevent it from terminating the literal).
        escaped_key = key.replace("'", "''")
        conditions.append(f"metadata_s['{escaped_key}'] = ?")

    if not conditions:
        return ""

    return " WHERE " + " AND ".join(conditions)
827
+
def _extract_where_clause_params(
    self,
    metadata: Dict[str, Any],
    link_from_tags: Optional[Tuple[str, str]] = None,
) -> List[Any]:
    """Collect the values bound to the WHERE-clause placeholders.

    Values are produced in this order: the two elements of
    ``link_from_tags`` (when given), followed by the string-coerced metadata
    values in sorted key order.

    Args:
        metadata: Metadata filter, mapping field name to the wanted value.
        link_from_tags: Optional pair whose two elements are bound first.

    Returns:
        The list of values to bind.

    Raises:
        ValueError: If a metadata key is not indexed according to
            ``self._metadata_indexing_policy``.
    """
    values: List[Any] = []

    if link_from_tags is not None:
        values += [link_from_tags[0], link_from_tags[1]]

    # Dict keys are unique, so ordering by key alone gives the same order
    # as ordering the (key, value) pairs.
    for field in sorted(metadata):
        if not _is_metadata_field_indexed(field, self._metadata_indexing_policy):
            raise ValueError(
                "Non-indexed metadata fields cannot be used in queries."
            )
        values.append(self._coerce_string(value=metadata[field]))

    return values
807
848
def _get_search_cql(
    self,
    has_limit: bool = False,
    columns: Optional[str] = CONTENT_COLUMNS,
    metadata_keys: Optional[List[str]] = None,
    has_embedding: bool = False,
    has_link_from_tags: bool = False,
) -> PreparedStatement:
    """Return the prepared SELECT statement for the requested query shape.

    The statement text is assembled from ``SELECT_CQL_TEMPLATE``. Prepared
    statements are stored in ``self._prepared_query_cache`` keyed by their
    CQL text, so each distinct query shape is prepared only once.

    Args:
        has_limit: Whether to append a ``LIMIT ?`` clause.
        columns: Column list to select; ``None`` falls back to
            ``CONTENT_COLUMNS``.
        metadata_keys: Metadata field names to filter on (``None`` or empty
            for no metadata filter).
        has_embedding: Whether to append
            ``ORDER BY text_embedding ANN OF ?``.
        has_link_from_tags: Whether to filter on ``link_from_tags``.

    Returns:
        The prepared statement, with its consistency level set to
        ``ConsistencyLevel.ONE``.

    Raises:
        ValueError: Propagated from the WHERE-clause builder when a
            metadata key is not indexed.
    """
    if columns is None:
        # Without this, the template would render the literal text "None".
        columns = CONTENT_COLUMNS

    where_clause = self._extract_where_clause_cql(
        metadata_keys=metadata_keys or [],
        has_link_from_tags=has_link_from_tags,
    )
    order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else ""
    # Every optional clause is either empty or starts with one space, so the
    # rendered statement (also the cache key) has no stray whitespace.
    limit_clause = " LIMIT ?" if has_limit else ""

    select_cql = SELECT_CQL_TEMPLATE.format(
        columns=columns,
        table_name=self.table_name(),
        where_clause=where_clause,
        order_clause=order_clause,
        limit_clause=limit_clause,
    )

    cached_query = self._prepared_query_cache.get(select_cql)
    if cached_query is not None:
        return cached_query

    prepared_query = self._session.prepare(select_cql)
    prepared_query.consistency_level = ConsistencyLevel.ONE
    self._prepared_query_cache[select_cql] = prepared_query

    return prepared_query
879
+
def _get_search_params(
    self,
    limit: Optional[int] = None,
    metadata: Optional[Dict[str, Any]] = None,
    embedding: Optional[List[float]] = None,
    link_from_tags: Optional[Tuple[str, str]] = None,
) -> Tuple[Any, ...]:
    """Build the values to bind to a statement from ``_get_search_cql``.

    The order follows the placeholder order in the statement text: the
    WHERE-clause values, then the embedding for the ``ORDER BY`` clause,
    then the limit. Arguments left as ``None`` contribute no value.

    Args:
        limit: Value for ``LIMIT ?``, if the statement has one.
        metadata: Metadata filter (``None`` or empty for no filter).
        embedding: Vector for ``ORDER BY text_embedding ANN OF ?``.
        link_from_tags: Pair bound to ``link_from_tags CONTAINS (?, ?)``.

    Returns:
        A flat tuple of the values to bind.

    Raises:
        ValueError: Propagated from the WHERE-clause parameter builder when
            a metadata key is not indexed.
    """
    where_params = self._extract_where_clause_params(
        metadata=metadata or {}, link_from_tags=link_from_tags
    )
    order_params = [embedding] if embedding is not None else []
    limit_params = [limit] if limit is not None else []

    return (*where_params, *order_params, *limit_params)
895
+
def _get_search_cql_and_params(
    self,
    limit: Optional[int] = None,
    columns: Optional[str] = CONTENT_COLUMNS,
    metadata: Optional[Dict[str, Any]] = None,
    embedding: Optional[List[float]] = None,
    link_from_tags: Optional[Tuple[str, str]] = None,
) -> Tuple[PreparedStatement, Tuple[Any, ...]]:
    """Return a prepared statement together with its bound values.

    Convenience wrapper for one-off queries: the statement shape is derived
    from which arguments are provided (``is not None``), and the values are
    built from the same arguments so the two always agree.

    Args:
        limit: Maximum number of rows; ``None`` omits the ``LIMIT`` clause.
        columns: Column list to select.
        metadata: Metadata filter (``None`` or empty for no filter).
        embedding: Vector to order by; ``None`` omits the ``ORDER BY``.
        link_from_tags: Pair to filter ``link_from_tags`` on; ``None``
            omits that condition.

    Returns:
        A ``(prepared_statement, parameters)`` pair.
    """
    # Avoid a shared mutable default; treat "not given" as "no filter".
    metadata = {} if metadata is None else metadata

    query = self._get_search_cql(
        has_limit=limit is not None,
        columns=columns,
        metadata_keys=list(metadata),
        has_embedding=embedding is not None,
        has_link_from_tags=link_from_tags is not None,
    )
    params = self._get_search_params(
        limit=limit,
        metadata=metadata,
        embedding=embedding,
        link_from_tags=link_from_tags,
    )
    return query, params
0 commit comments