@@ -27,6 +27,9 @@
 
 CONTENT_ID = "content_id"
 
+CONTENT_COLUMNS = "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
+
+SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {table_name} {where_clause} {limit_clause};"
 
 @dataclass
 class Node:
@@ -105,8 +108,10 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
 
 
 def _row_to_node(row: Any) -> Node:
-    metadata_s = row.get("metadata_s", {})
-    attributes_blob = row.get("attributes_blob")
+    metadata_s = row.metadata_s
+    if metadata_s is None:
+        metadata_s = {}
+    attributes_blob = row.attributes_blob
     attributes_dict = _deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
     links = _deserialize_links(row.links_blob)
     return Node(
@@ -164,7 +169,7 @@ def __init__(
         self._keyspace = keyspace
 
         self._metadata_indexing_policy = self._normalize_metadata_indexing_policy(
-            metadata_indexing
+            metadata_indexing=metadata_indexing,
         )
 
         if setup_mode == SetupMode.SYNC:
@@ -187,15 +192,15 @@ def __init__(
 
         self._query_by_id = session.prepare(
             f"""
-            SELECT content_id, kind, text_content, attributes_blob, links_blob
+            SELECT {CONTENT_COLUMNS}
             FROM {keyspace}.{node_table}
             WHERE content_id = ?
             """  # noqa: S608
         )
 
         self._query_by_embedding = session.prepare(
             f"""
-            SELECT content_id, kind, text_content, attributes_blob, links_blob
+            SELECT {CONTENT_COLUMNS}
             FROM {keyspace}.{node_table}
             ORDER BY text_embedding ANN OF ?
             LIMIT ?
@@ -307,6 +312,25 @@ def _apply_schema(self) -> None:
     def _concurrent_queries(self) -> ConcurrentQueries:
         return ConcurrentQueries(self._session)
 
+    def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str, Dict[str, str]]:
+        attributes_dict = {
+            k: self._coerce_string(v)
+            for k, v in metadata.items()
+            if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
+        }
+        if is_query and len(attributes_dict) > 0:
+            raise ValueError("Non-indexed metadata fields cannot be used in queries.")
+        attributes_blob = _serialize_metadata(attributes_dict)
+
+        metadata_indexed_dict = {
+            k: v
+            for k, v in metadata.items()
+            if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
+        }
+        metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
+        return (attributes_blob, metadata_s)
+
+
     # TODO: Async (aadd_nodes)
     def add_nodes(
         self,
@@ -342,19 +366,7 @@ def add_nodes(
                 if tag.direction in {"out", "bidir"}:
                     link_to_tags.add((tag.kind, tag.tag))
 
-                attributes_dict = {
-                    k: self._coerce_string(v)
-                    for k, v in metadata.items()
-                    if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
-                }
-                attributes_blob = _serialize_metadata(attributes_dict)
-
-                metadata_indexed_dict = {
-                    k: v
-                    for k, v in metadata.items()
-                    if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
-                }
-                metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
+                attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=False)
 
                 links_blob = _serialize_links(links)
                 cq.execute(
@@ -380,7 +392,7 @@ def _nodes_with_ids(
         results: Dict[str, Optional[Node]] = {}
         with self._concurrent_queries() as cq:
 
-            def add_nodes(rows: Iterable[Any]) -> None:
+            def node_callback(rows: Iterable[Any]) -> None:
                 # Should always be exactly one row here. We don't need to check
                 # 1. The query is for a `ID == ?` query on the primary key.
                 # 2. If it doesn't exist, the `get_result` method below will
@@ -393,7 +405,7 @@ def add_nodes(rows: Iterable[Any]) -> None:
                     # Mark this node ID as being fetched.
                     results[node_id] = None
                     cq.execute(
-                        self._query_by_id, parameters=(node_id,), callback=add_nodes
+                        self._query_by_id, parameters=(node_id,), callback=node_callback
                     )
 
         def get_result(node_id: str) -> Node:
@@ -643,6 +655,18 @@ def similarity_search(
         for row in self._session.execute(self._query_by_embedding, (embedding, k)):
             yield _row_to_node(row)
 
+    def metadata_search(self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5) -> Iterable[Node]:
+        query, params = self._get_metadata_search_cql(metadata=metadata, n=n)
+
+        prepared_query = self._session.prepare(query)
+
+        for row in self._session.execute(prepared_query, params):
+            yield _row_to_node(row)
+
+    def get_node(self, id: str) -> Node:
+        return self._nodes_with_ids(ids=[id])[0]
+
+
     def _get_outgoing_tags(
         self,
         source_ids: Iterable[str],
@@ -755,28 +779,6 @@ def _normalize_metadata_indexing_policy(
             )
         return (mode, fields)
 
-    def _split_metadata_fields(self, md_dict: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Split the *indexed* part of the metadata in separate parts,
-        one per Cassandra column.
-
-        Currently: everything gets cast to a string and goes to a single table
-        column. This means:
-        - strings are fine
-        - floats and integers v: they are cast to str(v)
-        - booleans: 'true'/'false' (JSON style)
-        - None => 'null' (JSON style)
-        - anything else v => str(v), no questions asked
-
-        Caveat: one gets strings back when reading metadata
-        """
-
-        # TODO: more care about types here
-        stringy_part = {k: self._coerce_string(v) for k, v in md_dict.items()}
-        return {
-            "metadata_s": stringy_part,
-        }
-
     @staticmethod
     def _coerce_string(value: Any) -> str:
         if isinstance(value, str):
@@ -794,4 +796,39 @@ def _coerce_string(value: Any) -> str:
             return json.dumps(value)
         else:
             # when all else fails ...
-            return str(value)
+            return str(value)
+
+    def _extract_where_clause_blocks(
+        self, metadata: Dict[str, Any]
+    ) -> Tuple[str, List[Any]]:
+
+        attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)
+
+        if len(metadata_s) == 0:
+            return "", []
+
+        wc_blocks: List[str] = []
+        vals_list: List[Any] = []
+
+        for k, v in sorted(metadata_s.items()):
+            wc_blocks.append(f"metadata_s['{k}'] = ?")
+            vals_list.append(v)
+
+        where_clause = "WHERE " + " AND ".join(wc_blocks)
+        return where_clause, vals_list
+
+
+    def _get_metadata_search_cql(self, n: int, metadata: Dict[str, Any]) -> Tuple[str, Tuple[Any, ...]]:
+        where_clause, get_cql_vals = self._extract_where_clause_blocks(metadata=metadata)
+        limit_clause = "LIMIT ?"
+        limit_cql_vals = [n]
+        select_vals = tuple(list(get_cql_vals) + limit_cql_vals)
+        #
+        select_cql = SELECT_CQL_TEMPLATE.format(
+            columns=CONTENT_COLUMNS,
+            table_name=f"{self._keyspace}.{self._node_table}",
+            where_clause=where_clause,
+            limit_clause=limit_clause,
+
+        )
+        return select_cql, select_vals
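
Usage sketch (illustrative, not part of the diff): the two public methods this change adds, metadata_search and get_node, can be exercised roughly as below. The store class name GraphStore, its constructor arguments, and the keyspace/table values are assumptions made for the example; only the method calls, the ValueError on non-indexed fields, and the shape of the generated CQL follow from the code above.

    # Hypothetical usage; "GraphStore" and its constructor arguments are assumed,
    # not taken from this diff.
    from cassandra.cluster import Cluster

    session = Cluster(["127.0.0.1"]).connect()
    store = GraphStore(session=session, keyspace="my_keyspace", node_table="my_nodes")

    # metadata_search accepts only *indexed* metadata fields; _parse_metadata(is_query=True)
    # raises ValueError for non-indexed ones. Internally it prepares a statement like:
    #   SELECT content_id, kind, text_content, attributes_blob, metadata_s, links_blob
    #   FROM my_keyspace.my_nodes WHERE metadata_s['topic'] = ? LIMIT ?;
    for node in store.metadata_search(metadata={"topic": "cassandra"}, n=10):
        print(node)

    # get_node fetches a single node by its content_id via _nodes_with_ids.
    node = store.get_node(id="doc-0001")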