12
12
Sequence ,
13
13
Set ,
14
14
Tuple ,
15
+ Union ,
15
16
cast ,
16
17
)
17
18
18
- from cassandra .cluster import ConsistencyLevel , Session
19
+ from cassandra .cluster import ConsistencyLevel , Session , ResponseFuture
19
20
from cassio .config import check_resolve_keyspace , check_resolve_session
20
21
21
22
from ._mmr_helper import MmrHelper
@@ -48,6 +49,20 @@ class SetupMode(Enum):
48
49
ASYNC = 2
49
50
OFF = 3
50
51
52
+ class MetadataIndexingMode (Enum ):
53
+ DEFAULT_TO_UNSEARCHABLE = 1
54
+ DEFAULT_TO_SEARCHABLE = 2
55
+
56
+ MetadataIndexingPolicy = Tuple [MetadataIndexingMode , Set [str ]]
57
+
58
+ def _is_metadata_field_indexed (field_name : str , policy : MetadataIndexingPolicy ) -> bool :
59
+ p_mode , p_fields = policy
60
+ if p_mode == MetadataIndexingMode .DEFAULT_TO_UNSEARCHABLE :
61
+ return field_name in p_fields
62
+ elif p_mode == MetadataIndexingMode .DEFAULT_TO_SEARCHABLE :
63
+ return field_name not in p_fields
64
+ else :
65
+ raise ValueError (f"Unexpected metadata indexing mode { p_mode } " )
51
66
52
67
def _serialize_metadata (md : Dict [str , Any ]) -> str :
53
68
if isinstance (md .get ("links" ), Set ):
@@ -90,12 +105,14 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
90
105
91
106
92
107
def _row_to_node (row : Any ) -> Node :
93
- metadata = _deserialize_metadata (row .metadata_blob )
108
+ metadata_s = row .get ("metadata_s" , {})
109
+ attributes_blob = row .get ("attributes_blob" )
110
+ attributes_dict = _deserialize_metadata (attributes_blob ) if attributes_blob is not None else {}
94
111
links = _deserialize_links (row .links_blob )
95
112
return Node (
96
113
id = row .content_id ,
97
114
text = row .text_content ,
98
- metadata = metadata ,
115
+ metadata = { ** attributes_dict , ** metadata_s } ,
99
116
links = links ,
100
117
)
101
118
@@ -130,6 +147,7 @@ def __init__(
130
147
session : Optional [Session ] = None ,
131
148
keyspace : Optional [str ] = None ,
132
149
setup_mode : SetupMode = SetupMode .SYNC ,
150
+ metadata_indexing : Union [Tuple [str , Iterable [str ]], str ] = "all" ,
133
151
):
134
152
session = check_resolve_session (session )
135
153
keyspace = check_resolve_keyspace (keyspace )
@@ -145,6 +163,10 @@ def __init__(
145
163
self ._session = session
146
164
self ._keyspace = keyspace
147
165
166
+ self ._metadata_indexing_policy = self ._normalize_metadata_indexing_policy (
167
+ metadata_indexing
168
+ )
169
+
148
170
if setup_mode == SetupMode .SYNC :
149
171
self ._apply_schema ()
150
172
elif setup_mode != SetupMode .OFF :
@@ -158,22 +180,22 @@ def __init__(
158
180
f"""
159
181
INSERT INTO { keyspace } .{ node_table } (
160
182
content_id, kind, text_content, text_embedding, link_to_tags,
161
- link_from_tags, metadata_blob , links_blob
162
- ) VALUES (?, '{ Kind .passage } ', ?, ?, ?, ?, ?, ?)
183
+ link_from_tags, attributes_blob, metadata_s , links_blob
184
+ ) VALUES (?, '{ Kind .passage } ', ?, ?, ?, ?, ?, ?, ? )
163
185
""" # noqa: S608
164
186
)
165
187
166
188
self ._query_by_id = session .prepare (
167
189
f"""
168
- SELECT content_id, kind, text_content, metadata_blob , links_blob
190
+ SELECT content_id, kind, text_content, attributes_blob , links_blob
169
191
FROM { keyspace } .{ node_table }
170
192
WHERE content_id = ?
171
193
""" # noqa: S608
172
194
)
173
195
174
196
self ._query_by_embedding = session .prepare (
175
197
f"""
176
- SELECT content_id, kind, text_content, metadata_blob , links_blob
198
+ SELECT content_id, kind, text_content, attributes_blob , links_blob
177
199
FROM { keyspace } .{ node_table }
178
200
ORDER BY text_embedding ANN OF ?
179
201
LIMIT ?
@@ -255,7 +277,8 @@ def _apply_schema(self) -> None:
255
277
256
278
link_to_tags SET<TUPLE<TEXT, TEXT>>,
257
279
link_from_tags SET<TUPLE<TEXT, TEXT>>,
258
- metadata_blob TEXT,
280
+ attributes_blob TEXT,
281
+ metadata_s MAP<TEXT,TEXT>,
259
282
links_blob TEXT,
260
283
261
284
PRIMARY KEY (content_id)
@@ -275,6 +298,12 @@ def _apply_schema(self) -> None:
275
298
USING 'StorageAttachedIndex';
276
299
""" )
277
300
301
+ self ._session .execute (f"""
302
+ CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _metadata_index
303
+ ON { self ._keyspace } .{ self ._node_table } (ENTRIES(metadata_s))
304
+ USING 'StorageAttachedIndex';
305
+ """ )
306
+
278
307
def _concurrent_queries (self ) -> ConcurrentQueries :
279
308
return ConcurrentQueries (self ._session )
280
309
@@ -313,7 +342,20 @@ def add_nodes(
313
342
if tag .direction in {"out" , "bidir" }:
314
343
link_to_tags .add ((tag .kind , tag .tag ))
315
344
316
- metadata_blob = _serialize_metadata (metadata )
345
+ attributes_dict = {
346
+ k : self ._coerce_string (v )
347
+ for k , v in metadata .items ()
348
+ if not _is_metadata_field_indexed (k , self ._metadata_indexing_policy )
349
+ }
350
+ attributes_blob = _serialize_metadata (attributes_dict )
351
+
352
+ metadata_indexed_dict = {
353
+ k : v
354
+ for k , v in metadata .items ()
355
+ if _is_metadata_field_indexed (k , self ._metadata_indexing_policy )
356
+ }
357
+ metadata_s = {k : self ._coerce_string (v ) for k , v in metadata_indexed_dict .items ()}
358
+
317
359
links_blob = _serialize_links (links )
318
360
cq .execute (
319
361
self ._insert_passage ,
@@ -323,7 +365,8 @@ def add_nodes(
323
365
text_embedding ,
324
366
link_to_tags ,
325
367
link_from_tags ,
326
- metadata_blob ,
368
+ attributes_blob ,
369
+ metadata_s ,
327
370
links_blob ,
328
371
),
329
372
)
@@ -670,3 +713,85 @@ def add_targets(rows: Iterable[Any]) -> None:
670
713
# TODO: Consider a combined limit based on the similarity and/or
671
714
# predicated MMR score?
672
715
return targets .values ()
716
+
717
+ @staticmethod
718
+ def _normalize_metadata_indexing_policy (
719
+ metadata_indexing : Union [Tuple [str , Iterable [str ]], str ]
720
+ ) -> MetadataIndexingPolicy :
721
+ mode : MetadataIndexingMode
722
+ fields : Set [str ]
723
+ # metadata indexing policy normalization:
724
+ if isinstance (metadata_indexing , str ):
725
+ if metadata_indexing .lower () == "all" :
726
+ mode , fields = (MetadataIndexingMode .DEFAULT_TO_SEARCHABLE , set ())
727
+ elif metadata_indexing .lower () == "none" :
728
+ mode , fields = (MetadataIndexingMode .DEFAULT_TO_UNSEARCHABLE , set ())
729
+ else :
730
+ raise ValueError (
731
+ f"Unsupported metadata_indexing value '{ metadata_indexing } '"
732
+ )
733
+ else :
734
+ assert len (metadata_indexing ) == 2
735
+ # it's a 2-tuple (mode, fields) still to normalize
736
+ _mode , _field_spec = metadata_indexing
737
+ fields = {_field_spec } if isinstance (_field_spec , str ) else set (_field_spec )
738
+ if _mode .lower () in {
739
+ "default_to_unsearchable" ,
740
+ "allowlist" ,
741
+ "allow" ,
742
+ "allow_list" ,
743
+ }:
744
+ mode = MetadataIndexingMode .DEFAULT_TO_UNSEARCHABLE
745
+ elif _mode .lower () in {
746
+ "default_to_searchable" ,
747
+ "denylist" ,
748
+ "deny" ,
749
+ "deny_list" ,
750
+ }:
751
+ mode = MetadataIndexingMode .DEFAULT_TO_SEARCHABLE
752
+ else :
753
+ raise ValueError (
754
+ f"Unsupported metadata indexing mode specification '{ _mode } '"
755
+ )
756
+ return (mode , fields )
757
+
758
+ def _split_metadata_fields (self , md_dict : Dict [str , Any ]) -> Dict [str , Any ]:
759
+ """
760
+ Split the *indexed* part of the metadata in separate parts,
761
+ one per Cassandra column.
762
+
763
+ Currently: everything gets cast to a string and goes to a single table
764
+ column. This means:
765
+ - strings are fine
766
+ - floats and integers v: they are cast to str(v)
767
+ - booleans: 'true'/'false' (JSON style)
768
+ - None => 'null' (JSON style)
769
+ - anything else v => str(v), no questions asked
770
+
771
+ Caveat: one gets strings back when reading metadata
772
+ """
773
+
774
+ # TODO: more care about types here
775
+ stringy_part = {k : self ._coerce_string (v ) for k , v in md_dict .items ()}
776
+ return {
777
+ "metadata_s" : stringy_part ,
778
+ }
779
+
780
+ @staticmethod
781
+ def _coerce_string (value : Any ) -> str :
782
+ if isinstance (value , str ):
783
+ return value
784
+ elif isinstance (value , bool ):
785
+ # bool MUST come before int in this chain of ifs!
786
+ return json .dumps (value )
787
+ elif isinstance (value , int ):
788
+ # we don't want to store '1' and '1.0' differently
789
+ # for the sake of metadata-filtered retrieval:
790
+ return json .dumps (float (value ))
791
+ elif isinstance (value , float ):
792
+ return json .dumps (value )
793
+ elif value is None :
794
+ return json .dumps (value )
795
+ else :
796
+ # when all else fails ...
797
+ return str (value )
0 commit comments