1
+ # ruff: noqa: B006
2
+
1
3
import json
2
4
import re
3
5
import secrets
16
18
cast ,
17
19
)
18
20
19
- from cassandra .cluster import ConsistencyLevel , Session , ResponseFuture
21
+ from cassandra .cluster import ConsistencyLevel , Session
20
22
from cassio .config import check_resolve_keyspace , check_resolve_session
21
23
22
24
from ._mmr_helper import MmrHelper
27
29
28
30
# Literal column name "content_id" (it is also the first entry of CONTENT_COLUMNS).
CONTENT_ID = "content_id"

# Comma-separated column list; used as the default `columns` value
# of `_get_search_cql`.
CONTENT_COLUMNS = (
    "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
)

# Skeleton of a SELECT statement. Every placeholder must be supplied to
# `str.format`; pass an empty string for a clause that is not needed.
SELECT_CQL_TEMPLATE = (
    "SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
)
33
40
34
41
@dataclass
35
42
class Node :
@@ -52,20 +59,25 @@ class SetupMode(Enum):
52
59
ASYNC = 2
53
60
OFF = 3
54
61
62
+
55
63
class MetadataIndexingMode(Enum):
    """Mode used to index metadata."""

    # Only the fields named in the policy are indexed.
    DEFAULT_TO_UNSEARCHABLE = 1
    # Every field is indexed except those named in the policy.
    DEFAULT_TO_SEARCHABLE = 2
58
68
69
+
59
70
# A metadata indexing policy is a (mode, set of field names) pair.
MetadataIndexingPolicy = Tuple [MetadataIndexingMode , Set [str ]]
60
71
72
+
61
73
def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool:
    """Tell whether a metadata field is indexed under the given policy.

    Args:
        field_name: name of the metadata field to check.
        policy: a ``(mode, fields)`` pair.

    Returns:
        Under ``DEFAULT_TO_UNSEARCHABLE``, True only if the field is listed
        in ``fields``; under ``DEFAULT_TO_SEARCHABLE``, True unless the field
        is listed in ``fields``.

    Raises:
        ValueError: if the mode is neither of the two known modes.
    """
    p_mode, p_fields = policy
    if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE:
        return field_name in p_fields
    if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
        return field_name not in p_fields
    raise ValueError(f"Unexpected metadata indexing mode {p_mode}")
80
+
69
81
70
82
def _serialize_metadata (md : Dict [str , Any ]) -> str :
71
83
if isinstance (md .get ("links" ), Set ):
@@ -112,7 +124,9 @@ def _row_to_node(row: Any) -> Node:
112
124
if metadata_s is None :
113
125
metadata_s = {}
114
126
attributes_blob = row .attributes_blob
115
- attributes_dict = _deserialize_metadata (attributes_blob ) if attributes_blob is not None else {}
127
+ attributes_dict = (
128
+ _deserialize_metadata (attributes_blob ) if attributes_blob is not None else {}
129
+ )
116
130
links = _deserialize_links (row .links_blob )
117
131
return Node (
118
132
id = row .content_id ,
@@ -237,6 +251,7 @@ def __init__(
237
251
)
238
252
239
253
def table_name(self) -> str:
    """Returns the fully qualified table name.

    The result is ``<keyspace>.<node_table>``, built from the instance's
    ``_keyspace`` and ``_node_table`` attributes.
    """
    return f"{self._keyspace}.{self._node_table}"
241
256
242
257
def _apply_schema (self ) -> None :
@@ -281,7 +296,9 @@ def _apply_schema(self) -> None:
281
296
def _concurrent_queries(self) -> ConcurrentQueries:
    """Return a new ``ConcurrentQueries`` built on this instance's session."""
    return ConcurrentQueries(self._session)
283
298
284
- def _parse_metadata (self , metadata : Dict [str , Any ], is_query : bool ) -> Tuple [str , Dict [str ,str ]]:
299
+ def _parse_metadata (
300
+ self , metadata : Dict [str , Any ], is_query : bool
301
+ ) -> Tuple [str , Dict [str , str ]]:
285
302
attributes_dict = {
286
303
k : self ._coerce_string (v )
287
304
for k , v in metadata .items ()
@@ -296,10 +313,11 @@ def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str
296
313
for k , v in metadata .items ()
297
314
if _is_metadata_field_indexed (k , self ._metadata_indexing_policy )
298
315
}
299
- metadata_s = {k : self ._coerce_string (v ) for k , v in metadata_indexed_dict .items ()}
316
+ metadata_s = {
317
+ k : self ._coerce_string (v ) for k , v in metadata_indexed_dict .items ()
318
+ }
300
319
return (attributes_blob , metadata_s )
301
320
302
-
303
321
# TODO: Async (aadd_nodes)
304
322
def add_nodes (
305
323
self ,
@@ -335,7 +353,9 @@ def add_nodes(
335
353
if tag .direction in {"out" , "bidir" }:
336
354
link_to_tags .add ((tag .kind , tag .tag ))
337
355
338
- attributes_blob , metadata_s = self ._parse_metadata (metadata = metadata , is_query = False )
356
+ attributes_blob , metadata_s = self ._parse_metadata (
357
+ metadata = metadata , is_query = False
358
+ )
339
359
340
360
links_blob = _serialize_links (links )
341
361
cq .execute (
@@ -440,7 +460,7 @@ def fetch_initial_candidates() -> None:
440
460
limit = fetch_k ,
441
461
columns = "content_id, text_embedding, link_to_tags" ,
442
462
metadata = metadata ,
443
- embedding = query_embedding
463
+ embedding = query_embedding ,
444
464
)
445
465
446
466
fetched = self ._session .execute (query = query , parameters = params )
@@ -515,7 +535,12 @@ def fetch_initial_candidates() -> None:
515
535
return self ._nodes_with_ids (helper .selected_ids )
516
536
517
537
def traversal_search (
518
- self , query : str , * , k : int = 4 , depth : int = 1 , metadata : Optional [Dict [str , Any ]] = [],
538
+ self ,
539
+ query : str ,
540
+ * ,
541
+ k : int = 4 ,
542
+ depth : int = 1 ,
543
+ metadata : Optional [Dict [str , Any ]] = [],
519
544
) -> Iterable [Node ]:
520
545
"""Retrieve documents from this knowledge store.
521
546
@@ -634,21 +659,26 @@ def similarity_search(
634
659
k : int = 4 ,
635
660
metadata : Optional [Dict [str , Any ]] = [],
636
661
) -> Iterable [Node ]:
637
- """Retrieve nodes similar to the given embedding, optionally filtered by metadata"""
638
- query , params = self ._get_search_cql (embedding = embedding , limit = k , metadata = metadata )
662
+ """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
663
+ query , params = self ._get_search_cql (
664
+ embedding = embedding , limit = k , metadata = metadata
665
+ )
639
666
640
667
for row in self ._session .execute (query , params ):
641
668
yield _row_to_node (row )
642
669
643
def metadata_search(
    self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5
) -> Iterable[Node]:
    """Retrieve nodes based on their metadata.

    Args:
        metadata: metadata filter, handed to ``_get_search_cql`` to build
            the query.
        n: handed to ``_get_search_cql`` as the query ``limit``.

    Yields:
        One ``Node`` per row returned by the session, converted with
        ``_row_to_node``. This is a generator, so the query is only
        executed once iteration starts.
    """
    query, params = self._get_search_cql(metadata=metadata, limit=n)

    for row in self._session.execute(query, params):
        yield _row_to_node(row)
648
678
649
def get_node(self, content_id: str) -> Node:
    """Get a node by its id.

    Args:
        content_id: id of the node to fetch (named so as not to shadow the
            ``id`` builtin).

    Returns:
        The first element of ``self._nodes_with_ids(ids=[content_id])``.
    """
    return self._nodes_with_ids(ids=[content_id])[0]
652
682
653
683
def _get_outgoing_tags (
654
684
self ,
@@ -723,7 +753,7 @@ def add_targets(rows: Iterable[Any]) -> None:
723
753
724
754
@staticmethod
725
755
def _normalize_metadata_indexing_policy (
726
- metadata_indexing : Union [Tuple [str , Iterable [str ]], str ]
756
+ metadata_indexing : Union [Tuple [str , Iterable [str ]], str ],
727
757
) -> MetadataIndexingPolicy :
728
758
mode : MetadataIndexingMode
729
759
fields : Set [str ]
@@ -738,7 +768,10 @@ def _normalize_metadata_indexing_policy(
738
768
f"Unsupported metadata_indexing value '{ metadata_indexing } '"
739
769
)
740
770
else :
741
- assert len (metadata_indexing ) == 2
771
+ if len (metadata_indexing ) != 2 : # noqa: PLR2004
772
+ raise ValueError (
773
+ f"Unsupported metadata_indexing value '{ metadata_indexing } '."
774
+ )
742
775
# it's a 2-tuple (mode, fields) still to normalize
743
776
_mode , _field_spec = metadata_indexing
744
777
fields = {_field_spec } if isinstance (_field_spec , str ) else set (_field_spec )
@@ -766,25 +799,21 @@ def _normalize_metadata_indexing_policy(
766
799
def _coerce_string (value : Any ) -> str :
767
800
if isinstance (value , str ):
768
801
return value
769
- elif isinstance (value , bool ):
802
+ if isinstance (value , bool ):
770
803
# bool MUST come before int in this chain of ifs!
771
804
return json .dumps (value )
772
- elif isinstance (value , int ):
805
+ if isinstance (value , int ):
773
806
# we don't want to store '1' and '1.0' differently
774
807
# for the sake of metadata-filtered retrieval:
775
808
return json .dumps (float (value ))
776
- elif isinstance (value , float ):
809
+ if isinstance (value , float ) or value is None :
777
810
return json .dumps (value )
778
- elif value is None :
779
- return json .dumps (value )
780
- else :
781
- # when all else fails ...
782
- return str (value )
811
+ # when all else fails ...
812
+ return str (value )
783
813
784
814
def _extract_where_clause_blocks (
785
815
self , metadata : Dict [str , Any ]
786
816
) -> Tuple [str , List [Any ]]:
787
-
788
817
_ , metadata_s = self ._parse_metadata (metadata = metadata , is_query = True )
789
818
790
819
if len (metadata_s ) == 0 :
@@ -800,13 +829,20 @@ def _extract_where_clause_blocks(
800
829
where_clause = "WHERE " + " AND " .join (wc_blocks )
801
830
return where_clause , vals_list
802
831
803
-
804
- def _get_search_cql (self , limit : int , columns : Optional [str ] = CONTENT_COLUMNS , metadata : Optional [Dict [str , Any ]] = {}, embedding : Optional [List [float ]] = None ) -> Tuple [str , Tuple [Any , ...]]:
805
- where_clause , get_cql_vals = self ._extract_where_clause_blocks (metadata = metadata )
832
+ def _get_search_cql (
833
+ self ,
834
+ limit : int ,
835
+ columns : Optional [str ] = CONTENT_COLUMNS ,
836
+ metadata : Optional [Dict [str , Any ]] = {},
837
+ embedding : Optional [List [float ]] = None ,
838
+ ) -> Tuple [str , Tuple [Any , ...]]:
839
+ where_clause , get_cql_vals = self ._extract_where_clause_blocks (
840
+ metadata = metadata
841
+ )
806
842
limit_clause = "LIMIT ?"
807
843
limit_cql_vals = [limit ]
808
844
809
- order_clause = ""
845
+ order_clause = ""
810
846
order_cql_vals = []
811
847
if embedding is not None :
812
848
order_clause = "ORDER BY text_embedding ANN OF ?"
@@ -819,7 +855,6 @@ def _get_search_cql(self, limit: int, columns: Optional[str] = CONTENT_COLUMNS,
819
855
where_clause = where_clause ,
820
856
order_clause = order_clause ,
821
857
limit_clause = limit_clause ,
822
-
823
858
)
824
859
prepared_query = self ._session .prepare (select_cql )
825
860
prepared_query .consistency_level = ConsistencyLevel .ONE
0 commit comments