Skip to content

Commit cde6f91

Browse files
committed
added metadata index put and get support
1 parent 02e46e6 commit cde6f91

File tree

1 file changed

+135
-10
lines changed

1 file changed

+135
-10
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 135 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
Sequence,
1313
Set,
1414
Tuple,
15+
Union,
1516
cast,
1617
)
1718

18-
from cassandra.cluster import ConsistencyLevel, Session
19+
from cassandra.cluster import ConsistencyLevel, Session, ResponseFuture
1920
from cassio.config import check_resolve_keyspace, check_resolve_session
2021

2122
from ._mmr_helper import MmrHelper
@@ -48,6 +49,20 @@ class SetupMode(Enum):
4849
ASYNC = 2
4950
OFF = 3
5051

52+
class MetadataIndexingMode(Enum):
    """Default treatment of a metadata field in an indexing policy.

    A policy pairs one of these modes with a set of field names; the set
    lists the exceptions to the default chosen here.
    """

    # Fields are NOT indexed unless they are named in the policy's field set.
    DEFAULT_TO_UNSEARCHABLE = 1

    # Fields ARE indexed unless they are named in the policy's field set.
    DEFAULT_TO_SEARCHABLE = 2
55+
56+
# Normalized indexing policy: (mode, field names that are exceptions to the mode).
MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]
57+
58+
def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool:
    """Tell whether `field_name` is indexed under the given policy.

    Args:
        field_name: name of the metadata field to check.
        policy: a `(mode, fields)` pair; `fields` lists the exceptions
            to the default expressed by `mode`.

    Returns:
        True if the field is indexed, False otherwise.

    Raises:
        ValueError: if the policy's mode is not a known
            `MetadataIndexingMode` member.
    """
    indexing_mode, listed_fields = policy

    # Allow-list behaviour: only the listed fields are indexed.
    if indexing_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE:
        return field_name in listed_fields

    # Deny-list behaviour: everything but the listed fields is indexed.
    if indexing_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
        return field_name not in listed_fields

    raise ValueError(f"Unexpected metadata indexing mode {indexing_mode}")
5166

5267
def _serialize_metadata(md: Dict[str, Any]) -> str:
5368
if isinstance(md.get("links"), Set):
@@ -90,12 +105,14 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
90105

91106

92107
def _row_to_node(row: Any) -> Node:
    """Build a `Node` from a row read from the node table.

    The row is read by attribute throughout (``row.content_id``,
    ``row.links_blob``, ...). The two metadata columns are optional: a
    query may not select them, and their stored value may be null, so
    both are fetched with a default instead of the dict-style ``.get``
    that attribute-style rows do not provide.

    Args:
        row: a result row exposing its columns as attributes.

    Returns:
        The node, whose metadata merges the deserialized
        ``attributes_blob`` with the ``metadata_s`` map.
    """
    # Missing column and null map both collapse to an empty dict.
    metadata_s = getattr(row, "metadata_s", None) or {}
    attributes_blob = getattr(row, "attributes_blob", None)
    attributes_dict = (
        _deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
    )
    links = _deserialize_links(row.links_blob)
    return Node(
        id=row.content_id,
        text=row.text_content,
        # On a key clash the `metadata_s` entry overrides the blob entry.
        metadata={**attributes_dict, **metadata_s},
        links=links,
    )
101118

@@ -130,6 +147,7 @@ def __init__(
130147
session: Optional[Session] = None,
131148
keyspace: Optional[str] = None,
132149
setup_mode: SetupMode = SetupMode.SYNC,
150+
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
133151
):
134152
session = check_resolve_session(session)
135153
keyspace = check_resolve_keyspace(keyspace)
@@ -145,6 +163,10 @@ def __init__(
145163
self._session = session
146164
self._keyspace = keyspace
147165

166+
self._metadata_indexing_policy = self._normalize_metadata_indexing_policy(
167+
metadata_indexing
168+
)
169+
148170
if setup_mode == SetupMode.SYNC:
149171
self._apply_schema()
150172
elif setup_mode != SetupMode.OFF:
@@ -158,22 +180,22 @@ def __init__(
158180
f"""
159181
INSERT INTO {keyspace}.{node_table} (
160182
content_id, kind, text_content, text_embedding, link_to_tags,
161-
link_from_tags, metadata_blob, links_blob
162-
) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?)
183+
link_from_tags, attributes_blob, metadata_s, links_blob
184+
) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?)
163185
""" # noqa: S608
164186
)
165187

166188
self._query_by_id = session.prepare(
167189
f"""
168-
SELECT content_id, kind, text_content, metadata_blob, links_blob
190+
SELECT content_id, kind, text_content, attributes_blob, links_blob
169191
FROM {keyspace}.{node_table}
170192
WHERE content_id = ?
171193
""" # noqa: S608
172194
)
173195

174196
self._query_by_embedding = session.prepare(
175197
f"""
176-
SELECT content_id, kind, text_content, metadata_blob, links_blob
198+
SELECT content_id, kind, text_content, attributes_blob, links_blob
177199
FROM {keyspace}.{node_table}
178200
ORDER BY text_embedding ANN OF ?
179201
LIMIT ?
@@ -255,7 +277,8 @@ def _apply_schema(self) -> None:
255277
256278
link_to_tags SET<TUPLE<TEXT, TEXT>>,
257279
link_from_tags SET<TUPLE<TEXT, TEXT>>,
258-
metadata_blob TEXT,
280+
attributes_blob TEXT,
281+
metadata_s MAP<TEXT,TEXT>,
259282
links_blob TEXT,
260283
261284
PRIMARY KEY (content_id)
@@ -275,6 +298,12 @@ def _apply_schema(self) -> None:
275298
USING 'StorageAttachedIndex';
276299
""")
277300

301+
self._session.execute(f"""
302+
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_index
303+
ON {self._keyspace}.{self._node_table}(ENTRIES(metadata_s))
304+
USING 'StorageAttachedIndex';
305+
""")
306+
278307
def _concurrent_queries(self) -> ConcurrentQueries:
279308
return ConcurrentQueries(self._session)
280309

@@ -313,7 +342,20 @@ def add_nodes(
313342
if tag.direction in {"out", "bidir"}:
314343
link_to_tags.add((tag.kind, tag.tag))
315344

316-
metadata_blob = _serialize_metadata(metadata)
345+
attributes_dict = {
346+
k: self._coerce_string(v)
347+
for k, v in metadata.items()
348+
if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
349+
}
350+
attributes_blob = _serialize_metadata(attributes_dict)
351+
352+
metadata_indexed_dict = {
353+
k: v
354+
for k, v in metadata.items()
355+
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
356+
}
357+
metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
358+
317359
links_blob = _serialize_links(links)
318360
cq.execute(
319361
self._insert_passage,
@@ -323,7 +365,8 @@ def add_nodes(
323365
text_embedding,
324366
link_to_tags,
325367
link_from_tags,
326-
metadata_blob,
368+
attributes_blob,
369+
metadata_s,
327370
links_blob,
328371
),
329372
)
@@ -670,3 +713,85 @@ def add_targets(rows: Iterable[Any]) -> None:
670713
# TODO: Consider a combined limit based on the similarity and/or
671714
# predicated MMR score?
672715
return targets.values()
716+
717+
@staticmethod
def _normalize_metadata_indexing_policy(
    metadata_indexing: Union[Tuple[str, Iterable[str]], str],
) -> MetadataIndexingPolicy:
    """Turn a user-supplied indexing specification into a normalized policy.

    Args:
        metadata_indexing: either the string ``"all"`` / ``"none"``
            (case-insensitive), or a 2-item ``(mode, fields)`` pair where
            ``mode`` is one of the allow-list / deny-list spellings
            accepted below (case-insensitive) and ``fields`` is a single
            field name or an iterable of field names.

    Returns:
        A ``(MetadataIndexingMode, set_of_field_names)`` tuple.

    Raises:
        ValueError: if the string shorthand or the mode is not
            recognised, or if the non-string specification does not
            have exactly two items.
    """
    if isinstance(metadata_indexing, str):
        shorthand = metadata_indexing.lower()
        if shorthand == "all":
            # Nothing is excluded from indexing.
            return (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set())
        if shorthand == "none":
            # Nothing is included in indexing.
            return (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set())
        raise ValueError(
            f"Unsupported metadata_indexing value '{metadata_indexing}'"
        )

    # Explicit check instead of `assert`: assertions are removed when Python
    # runs with -O, which would let a malformed specification through.
    if len(metadata_indexing) != 2:
        raise ValueError(
            "metadata_indexing must be 'all', 'none' or a 2-item "
            f"(mode, fields) pair, got {metadata_indexing!r}"
        )

    # It's a 2-tuple (mode, fields) still to normalize.
    _mode, _field_spec = metadata_indexing
    # A lone string is one field name, not an iterable of characters.
    fields: Set[str] = (
        {_field_spec} if isinstance(_field_spec, str) else set(_field_spec)
    )

    mode_key = _mode.lower()
    if mode_key in {
        "default_to_unsearchable",
        "allowlist",
        "allow",
        "allow_list",
    }:
        return (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, fields)
    if mode_key in {
        "default_to_searchable",
        "denylist",
        "deny",
        "deny_list",
    }:
        return (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, fields)
    raise ValueError(
        f"Unsupported metadata indexing mode specification '{_mode}'"
    )
757+
758+
def _split_metadata_fields(self, md_dict: Dict[str, Any]) -> Dict[str, Any]:
759+
"""
760+
Split the *indexed* part of the metadata in separate parts,
761+
one per Cassandra column.
762+
763+
Currently: everything gets cast to a string and goes to a single table
764+
column. This means:
765+
- strings are fine
766+
- floats and integers v: they are cast to str(v)
767+
- booleans: 'true'/'false' (JSON style)
768+
- None => 'null' (JSON style)
769+
- anything else v => str(v), no questions asked
770+
771+
Caveat: one gets strings back when reading metadata
772+
"""
773+
774+
# TODO: more care about types here
775+
stringy_part = {k: self._coerce_string(v) for k, v in md_dict.items()}
776+
return {
777+
"metadata_s": stringy_part,
778+
}
779+
780+
@staticmethod
781+
def _coerce_string(value: Any) -> str:
782+
if isinstance(value, str):
783+
return value
784+
elif isinstance(value, bool):
785+
# bool MUST come before int in this chain of ifs!
786+
return json.dumps(value)
787+
elif isinstance(value, int):
788+
# we don't want to store '1' and '1.0' differently
789+
# for the sake of metadata-filtered retrieval:
790+
return json.dumps(float(value))
791+
elif isinstance(value, float):
792+
return json.dumps(value)
793+
elif value is None:
794+
return json.dumps(value)
795+
else:
796+
# when all else fails ...
797+
return str(value)

0 commit comments

Comments
 (0)