12
12
Sequence ,
13
13
Set ,
14
14
Tuple ,
15
+ Union ,
15
16
cast ,
16
17
)
17
18
18
- from cassandra .cluster import ConsistencyLevel , Session
19
+ from cassandra .cluster import ConsistencyLevel , Session , ResponseFuture
19
20
from cassio .config import check_resolve_keyspace , check_resolve_session
20
21
21
22
from ._mmr_helper import MmrHelper
@@ -48,6 +49,20 @@ class SetupMode(Enum):
48
49
ASYNC = 2
49
50
OFF = 3
50
51
52
+ class MetadataIndexingMode (Enum ):
53
+ DEFAULT_TO_UNSEARCHABLE = 1
54
+ DEFAULT_TO_SEARCHABLE = 2
55
+
56
+ MetadataIndexingPolicy = Tuple [MetadataIndexingMode , Set [str ]]
57
+
58
+ def _is_metadata_field_indexed (field_name : str , policy : MetadataIndexingPolicy ) -> bool :
59
+ p_mode , p_fields = policy
60
+ if p_mode == MetadataIndexingMode .DEFAULT_TO_UNSEARCHABLE :
61
+ return field_name in p_fields
62
+ elif p_mode == MetadataIndexingMode .DEFAULT_TO_SEARCHABLE :
63
+ return field_name not in p_fields
64
+ else :
65
+ raise ValueError (f"Unexpected metadata indexing mode { p_mode } " )
51
66
52
67
def _serialize_metadata (md : Dict [str , Any ]) -> str :
53
68
if isinstance (md .get ("links" ), Set ):
@@ -88,12 +103,14 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
88
103
89
104
90
105
def _row_to_node (row : Any ) -> Node :
91
- metadata = _deserialize_metadata (row .metadata_blob )
106
+ metadata_s = row .get ("metadata_s" , {})
107
+ attributes_blob = row .get ("attributes_blob" )
108
+ attributes_dict = _deserialize_metadata (attributes_blob ) if attributes_blob is not None else {}
92
109
links = _deserialize_links (row .links_blob )
93
110
return Node (
94
111
id = row .content_id ,
95
112
text = row .text_content ,
96
- metadata = metadata ,
113
+ metadata = { ** attributes_dict , ** metadata_s } ,
97
114
links = links ,
98
115
)
99
116
@@ -128,6 +145,7 @@ def __init__(
128
145
session : Optional [Session ] = None ,
129
146
keyspace : Optional [str ] = None ,
130
147
setup_mode : SetupMode = SetupMode .SYNC ,
148
+ metadata_indexing : Union [Tuple [str , Iterable [str ]], str ] = "all" ,
131
149
):
132
150
session = check_resolve_session (session )
133
151
keyspace = check_resolve_keyspace (keyspace )
@@ -143,6 +161,10 @@ def __init__(
143
161
self ._session = session
144
162
self ._keyspace = keyspace
145
163
164
+ self ._metadata_indexing_policy = self ._normalize_metadata_indexing_policy (
165
+ metadata_indexing
166
+ )
167
+
146
168
if setup_mode == SetupMode .SYNC :
147
169
self ._apply_schema ()
148
170
elif setup_mode != SetupMode .OFF :
@@ -156,22 +178,22 @@ def __init__(
156
178
f"""
157
179
INSERT INTO { keyspace } .{ node_table } (
158
180
content_id, kind, text_content, text_embedding, link_to_tags,
159
- link_from_tags, metadata_blob , links_blob
160
- ) VALUES (?, '{ Kind .passage } ', ?, ?, ?, ?, ?, ?)
181
+ link_from_tags, attributes_blob, metadata_s , links_blob
182
+ ) VALUES (?, '{ Kind .passage } ', ?, ?, ?, ?, ?, ?, ? )
161
183
""" # noqa: S608
162
184
)
163
185
164
186
self ._query_by_id = session .prepare (
165
187
f"""
166
- SELECT content_id, kind, text_content, metadata_blob , links_blob
188
+ SELECT content_id, kind, text_content, attributes_blob , links_blob
167
189
FROM { keyspace } .{ node_table }
168
190
WHERE content_id = ?
169
191
""" # noqa: S608
170
192
)
171
193
172
194
self ._query_by_embedding = session .prepare (
173
195
f"""
174
- SELECT content_id, kind, text_content, metadata_blob , links_blob
196
+ SELECT content_id, kind, text_content, attributes_blob , links_blob
175
197
FROM { keyspace } .{ node_table }
176
198
ORDER BY text_embedding ANN OF ?
177
199
LIMIT ?
@@ -253,7 +275,8 @@ def _apply_schema(self) -> None:
253
275
254
276
link_to_tags SET<TUPLE<TEXT, TEXT>>,
255
277
link_from_tags SET<TUPLE<TEXT, TEXT>>,
256
- metadata_blob TEXT,
278
+ attributes_blob TEXT,
279
+ metadata_s MAP<TEXT,TEXT>,
257
280
links_blob TEXT,
258
281
259
282
PRIMARY KEY (content_id)
@@ -273,6 +296,12 @@ def _apply_schema(self) -> None:
273
296
USING 'StorageAttachedIndex';
274
297
""" )
275
298
299
+ self ._session .execute (f"""
300
+ CREATE CUSTOM INDEX IF NOT EXISTS { self ._node_table } _metadata_index
301
+ ON { self ._keyspace } .{ self ._node_table } (ENTRIES(metadata_s))
302
+ USING 'StorageAttachedIndex';
303
+ """ )
304
+
276
305
def _concurrent_queries (self ) -> ConcurrentQueries :
277
306
return ConcurrentQueries (self ._session )
278
307
@@ -311,7 +340,20 @@ def add_nodes(
311
340
if tag .direction in {"out" , "bidir" }:
312
341
link_to_tags .add ((tag .kind , tag .tag ))
313
342
314
- metadata_blob = _serialize_metadata (metadata )
343
+ attributes_dict = {
344
+ k : self ._coerce_string (v )
345
+ for k , v in metadata .items ()
346
+ if not _is_metadata_field_indexed (k , self ._metadata_indexing_policy )
347
+ }
348
+ attributes_blob = _serialize_metadata (attributes_dict )
349
+
350
+ metadata_indexed_dict = {
351
+ k : v
352
+ for k , v in metadata .items ()
353
+ if _is_metadata_field_indexed (k , self ._metadata_indexing_policy )
354
+ }
355
+ metadata_s = {k : self ._coerce_string (v ) for k , v in metadata_indexed_dict .items ()}
356
+
315
357
links_blob = _serialize_links (links )
316
358
cq .execute (
317
359
self ._insert_passage ,
@@ -321,7 +363,8 @@ def add_nodes(
321
363
text_embedding ,
322
364
link_to_tags ,
323
365
link_from_tags ,
324
- metadata_blob ,
366
+ attributes_blob ,
367
+ metadata_s ,
325
368
links_blob ,
326
369
),
327
370
)
@@ -668,3 +711,85 @@ def add_targets(rows: Iterable[Any]) -> None:
668
711
# TODO: Consider a combined limit based on the similarity and/or
669
712
# predicated MMR score?
670
713
return targets .values ()
714
+
715
+ @staticmethod
716
+ def _normalize_metadata_indexing_policy (
717
+ metadata_indexing : Union [Tuple [str , Iterable [str ]], str ]
718
+ ) -> MetadataIndexingPolicy :
719
+ mode : MetadataIndexingMode
720
+ fields : Set [str ]
721
+ # metadata indexing policy normalization:
722
+ if isinstance (metadata_indexing , str ):
723
+ if metadata_indexing .lower () == "all" :
724
+ mode , fields = (MetadataIndexingMode .DEFAULT_TO_SEARCHABLE , set ())
725
+ elif metadata_indexing .lower () == "none" :
726
+ mode , fields = (MetadataIndexingMode .DEFAULT_TO_UNSEARCHABLE , set ())
727
+ else :
728
+ raise ValueError (
729
+ f"Unsupported metadata_indexing value '{ metadata_indexing } '"
730
+ )
731
+ else :
732
+ assert len (metadata_indexing ) == 2
733
+ # it's a 2-tuple (mode, fields) still to normalize
734
+ _mode , _field_spec = metadata_indexing
735
+ fields = {_field_spec } if isinstance (_field_spec , str ) else set (_field_spec )
736
+ if _mode .lower () in {
737
+ "default_to_unsearchable" ,
738
+ "allowlist" ,
739
+ "allow" ,
740
+ "allow_list" ,
741
+ }:
742
+ mode = MetadataIndexingMode .DEFAULT_TO_UNSEARCHABLE
743
+ elif _mode .lower () in {
744
+ "default_to_searchable" ,
745
+ "denylist" ,
746
+ "deny" ,
747
+ "deny_list" ,
748
+ }:
749
+ mode = MetadataIndexingMode .DEFAULT_TO_SEARCHABLE
750
+ else :
751
+ raise ValueError (
752
+ f"Unsupported metadata indexing mode specification '{ _mode } '"
753
+ )
754
+ return (mode , fields )
755
+
756
+ def _split_metadata_fields (self , md_dict : Dict [str , Any ]) -> Dict [str , Any ]:
757
+ """
758
+ Split the *indexed* part of the metadata in separate parts,
759
+ one per Cassandra column.
760
+
761
+ Currently: everything gets cast to a string and goes to a single table
762
+ column. This means:
763
+ - strings are fine
764
+ - floats and integers v: they are cast to str(v)
765
+ - booleans: 'true'/'false' (JSON style)
766
+ - None => 'null' (JSON style)
767
+ - anything else v => str(v), no questions asked
768
+
769
+ Caveat: one gets strings back when reading metadata
770
+ """
771
+
772
+ # TODO: more care about types here
773
+ stringy_part = {k : self ._coerce_string (v ) for k , v in md_dict .items ()}
774
+ return {
775
+ "metadata_s" : stringy_part ,
776
+ }
777
+
778
+ @staticmethod
779
+ def _coerce_string (value : Any ) -> str :
780
+ if isinstance (value , str ):
781
+ return value
782
+ elif isinstance (value , bool ):
783
+ # bool MUST come before int in this chain of ifs!
784
+ return json .dumps (value )
785
+ elif isinstance (value , int ):
786
+ # we don't want to store '1' and '1.0' differently
787
+ # for the sake of metadata-filtered retrieval:
788
+ return json .dumps (float (value ))
789
+ elif isinstance (value , float ):
790
+ return json .dumps (value )
791
+ elif value is None :
792
+ return json .dumps (value )
793
+ else :
794
+ # when all else fails ...
795
+ return str (value )
0 commit comments