Skip to content

Commit 8b5360c

Browse files
committed
added metadata index put and get support
1 parent 1012c3e commit 8b5360c

File tree

3 files changed

+195
-10
lines changed

3 files changed

+195
-10
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 135 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
Sequence,
1313
Set,
1414
Tuple,
15+
Union,
1516
cast,
1617
)
1718

18-
from cassandra.cluster import ConsistencyLevel, Session
19+
from cassandra.cluster import ConsistencyLevel, Session, ResponseFuture
1920
from cassio.config import check_resolve_keyspace, check_resolve_session
2021

2122
from ._mmr_helper import MmrHelper
@@ -48,6 +49,20 @@ class SetupMode(Enum):
4849
ASYNC = 2
4950
OFF = 3
5051

52+
class MetadataIndexingMode(Enum):
    """How the field set of a metadata indexing policy is interpreted.

    A policy pairs one of these modes with a set of field names; the mode
    decides whether that set lists the indexed fields or the excluded ones.
    """

    # Fields are not indexed unless they are named in the policy's set.
    DEFAULT_TO_UNSEARCHABLE = 1
    # Fields are indexed unless they are named in the policy's set.
    DEFAULT_TO_SEARCHABLE = 2
55+
56+
MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]
57+
58+
def _is_metadata_field_indexed(
    field_name: str, policy: MetadataIndexingPolicy
) -> bool:
    """Tell whether a metadata field is indexed under the given policy.

    Args:
        field_name: name of the metadata field to check.
        policy: a ``(mode, fields)`` pair.

    Returns:
        Under ``DEFAULT_TO_SEARCHABLE``, True unless the field is listed in
        the policy's field set; under ``DEFAULT_TO_UNSEARCHABLE``, True only
        if the field is listed.

    Raises:
        ValueError: if the policy's mode is neither of the two known modes.
    """
    p_mode, listed_fields = policy
    if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
        # The set names the exceptions that are NOT indexed.
        return field_name not in listed_fields
    if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE:
        # The set names the only fields that ARE indexed.
        return field_name in listed_fields
    raise ValueError(f"Unexpected metadata indexing mode {p_mode}")
5166

5267
def _serialize_metadata(md: Dict[str, Any]) -> str:
5368
if isinstance(md.get("links"), Set):
@@ -88,12 +103,14 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
88103

89104

90105
def _row_to_node(row: Any) -> Node:
    """Build a ``Node`` from a row of the node table.

    The node metadata is the union of the non-indexed part (JSON text held
    in ``attributes_blob``) and the indexed part (the ``metadata_s`` map).
    On a key clash the ``metadata_s`` entry wins.

    Args:
        row: a result row exposing its columns as attributes. It must carry
            ``content_id``, ``text_content`` and ``links_blob``;
            ``attributes_blob`` and ``metadata_s`` are optional.

    Returns:
        The reconstructed ``Node``.
    """
    # Read the optional columns by attribute, like every other column below:
    # attribute-style rows have no ``.get``. ``getattr`` with a default also
    # covers queries that do not select the column, and ``or {}`` covers a
    # column that comes back as None.
    metadata_s = getattr(row, "metadata_s", None) or {}
    attributes_blob = getattr(row, "attributes_blob", None)
    attributes_dict = (
        _deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
    )
    links = _deserialize_links(row.links_blob)
    return Node(
        id=row.content_id,
        text=row.text_content,
        metadata={**attributes_dict, **metadata_s},
        links=links,
    )
99116

@@ -128,6 +145,7 @@ def __init__(
128145
session: Optional[Session] = None,
129146
keyspace: Optional[str] = None,
130147
setup_mode: SetupMode = SetupMode.SYNC,
148+
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
131149
):
132150
session = check_resolve_session(session)
133151
keyspace = check_resolve_keyspace(keyspace)
@@ -143,6 +161,10 @@ def __init__(
143161
self._session = session
144162
self._keyspace = keyspace
145163

164+
self._metadata_indexing_policy = self._normalize_metadata_indexing_policy(
165+
metadata_indexing
166+
)
167+
146168
if setup_mode == SetupMode.SYNC:
147169
self._apply_schema()
148170
elif setup_mode != SetupMode.OFF:
@@ -156,22 +178,22 @@ def __init__(
156178
f"""
157179
INSERT INTO {keyspace}.{node_table} (
158180
content_id, kind, text_content, text_embedding, link_to_tags,
159-
link_from_tags, metadata_blob, links_blob
160-
) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?)
181+
link_from_tags, attributes_blob, metadata_s, links_blob
182+
) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?)
161183
""" # noqa: S608
162184
)
163185

164186
self._query_by_id = session.prepare(
165187
f"""
166-
SELECT content_id, kind, text_content, metadata_blob, links_blob
188+
SELECT content_id, kind, text_content, attributes_blob, links_blob
167189
FROM {keyspace}.{node_table}
168190
WHERE content_id = ?
169191
""" # noqa: S608
170192
)
171193

172194
self._query_by_embedding = session.prepare(
173195
f"""
174-
SELECT content_id, kind, text_content, metadata_blob, links_blob
196+
SELECT content_id, kind, text_content, attributes_blob, links_blob
175197
FROM {keyspace}.{node_table}
176198
ORDER BY text_embedding ANN OF ?
177199
LIMIT ?
@@ -253,7 +275,8 @@ def _apply_schema(self) -> None:
253275
254276
link_to_tags SET<TUPLE<TEXT, TEXT>>,
255277
link_from_tags SET<TUPLE<TEXT, TEXT>>,
256-
metadata_blob TEXT,
278+
attributes_blob TEXT,
279+
metadata_s MAP<TEXT,TEXT>,
257280
links_blob TEXT,
258281
259282
PRIMARY KEY (content_id)
@@ -273,6 +296,12 @@ def _apply_schema(self) -> None:
273296
USING 'StorageAttachedIndex';
274297
""")
275298

299+
self._session.execute(f"""
300+
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_index
301+
ON {self._keyspace}.{self._node_table}(ENTRIES(metadata_s))
302+
USING 'StorageAttachedIndex';
303+
""")
304+
276305
def _concurrent_queries(self) -> ConcurrentQueries:
277306
return ConcurrentQueries(self._session)
278307

@@ -311,7 +340,20 @@ def add_nodes(
311340
if tag.direction in {"out", "bidir"}:
312341
link_to_tags.add((tag.kind, tag.tag))
313342

314-
metadata_blob = _serialize_metadata(metadata)
343+
attributes_dict = {
344+
k: self._coerce_string(v)
345+
for k, v in metadata.items()
346+
if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
347+
}
348+
attributes_blob = _serialize_metadata(attributes_dict)
349+
350+
metadata_indexed_dict = {
351+
k: v
352+
for k, v in metadata.items()
353+
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
354+
}
355+
metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
356+
315357
links_blob = _serialize_links(links)
316358
cq.execute(
317359
self._insert_passage,
@@ -321,7 +363,8 @@ def add_nodes(
321363
text_embedding,
322364
link_to_tags,
323365
link_from_tags,
324-
metadata_blob,
366+
attributes_blob,
367+
metadata_s,
325368
links_blob,
326369
),
327370
)
@@ -668,3 +711,85 @@ def add_targets(rows: Iterable[Any]) -> None:
668711
# TODO: Consider a combined limit based on the similarity and/or
669712
# predicated MMR score?
670713
return targets.values()
714+
715+
@staticmethod
def _normalize_metadata_indexing_policy(
    metadata_indexing: Union[Tuple[str, Iterable[str]], str],
) -> MetadataIndexingPolicy:
    """Turn a user-facing indexing specification into a ``(mode, fields)`` policy.

    Args:
        metadata_indexing: either a shorthand string or a 2-item
            ``(mode, fields)`` sequence. Strings are matched
            case-insensitively.

            - ``"all"``: every field is indexed.
            - ``"none"``: no field is indexed.
            - ``(mode, fields)``: ``mode`` is one of
              ``default_to_unsearchable`` / ``allowlist`` / ``allow`` /
              ``allow_list`` (only ``fields`` are indexed) or
              ``default_to_searchable`` / ``denylist`` / ``deny`` /
              ``deny_list`` (everything but ``fields`` is indexed).
              ``fields`` is a single field name or an iterable of names.

    Returns:
        The normalized ``(MetadataIndexingMode, set of field names)`` pair.

    Raises:
        ValueError: on an unknown shorthand string, a sequence that does not
            have exactly two items, or an unknown mode name.
    """
    mode: MetadataIndexingMode
    fields: Set[str]
    if isinstance(metadata_indexing, str):
        shorthand = metadata_indexing.lower()
        if shorthand == "all":
            return (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set())
        if shorthand == "none":
            return (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set())
        raise ValueError(
            f"Unsupported metadata_indexing value '{metadata_indexing}'"
        )

    # Validate explicitly rather than with `assert`: assertions are removed
    # when Python runs with -O, which would let malformed input through.
    if len(metadata_indexing) != 2:
        raise ValueError(
            "metadata_indexing must be 'all', 'none' or a 2-item "
            f"(mode, fields) sequence, got {metadata_indexing!r}"
        )

    # it's a 2-tuple (mode, fields) still to normalize
    _mode, _field_spec = metadata_indexing
    # A bare string is one field name, not an iterable of characters.
    fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec)
    mode_key = _mode.lower()
    if mode_key in {
        "default_to_unsearchable",
        "allowlist",
        "allow",
        "allow_list",
    }:
        mode = MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE
    elif mode_key in {
        "default_to_searchable",
        "denylist",
        "deny",
        "deny_list",
    }:
        mode = MetadataIndexingMode.DEFAULT_TO_SEARCHABLE
    else:
        raise ValueError(
            f"Unsupported metadata indexing mode specification '{_mode}'"
        )
    return (mode, fields)
755+
756+
def _split_metadata_fields(self, md_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Arrange the *indexed* part of the metadata into separate parts,
    one per Cassandra column, keyed by column name.

    Currently there is a single part, "metadata_s": every value is turned
    into a string by `_coerce_string` (see that method for the exact
    per-type rules).

    Caveat: one gets strings back when reading metadata
    """

    # TODO: more care about types here
    coerced: Dict[str, str] = {}
    for field_name, field_value in md_dict.items():
        coerced[field_name] = self._coerce_string(field_value)
    return {"metadata_s": coerced}
777+
778+
@staticmethod
779+
def _coerce_string(value: Any) -> str:
780+
if isinstance(value, str):
781+
return value
782+
elif isinstance(value, bool):
783+
# bool MUST come before int in this chain of ifs!
784+
return json.dumps(value)
785+
elif isinstance(value, int):
786+
# we don't want to store '1' and '1.0' differently
787+
# for the sake of metadata-filtered retrieval:
788+
return json.dumps(float(value))
789+
elif isinstance(value, float):
790+
return json.dumps(value)
791+
elif value is None:
792+
return json.dumps(value)
793+
else:
794+
# when all else fails ...
795+
return str(value)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Normalization of metadata policy specification options
3+
"""
4+
5+
from ragstack_knowledge_store.graph_store import MetadataIndexingMode, GraphStore
6+
7+
8+
class TestNormalizeMetadataPolicy:
    def test_normalize_metadata_policy(self) -> None:
        """Each accepted spelling normalizes to the expected (mode, fields)."""
        normalize = GraphStore._normalize_metadata_indexing_policy
        searchable = MetadataIndexingMode.DEFAULT_TO_SEARCHABLE
        unsearchable = MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE

        # shorthand string: index everything
        assert normalize("all") == (searchable, set())

        # shorthand string: index nothing
        assert normalize("none") == (unsearchable, set())

        # explicit mode name, mixed case, with a list of fields
        policy_allow = normalize(("default_to_Unsearchable", ["x", "y"]))
        assert policy_allow == (unsearchable, {"x", "y"})

        # deny-list alias, mixed case
        policy_deny = normalize(("DenyList", ["z"]))
        assert policy_deny == (searchable, {"z"})

        # a single field given as a bare string
        policy_single = normalize(("deny_LIST", "singlefield"))
        assert policy_single == (searchable, {"singlefield"})
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Stringification of everything in the simple metadata handling
3+
"""
4+
5+
from ragstack_knowledge_store.graph_store import GraphStore, SetupMode
6+
7+
8+
class TestMetadataStringCoercion:
    def test_metadata_string_coercion(self) -> None:
        """Every supported value kind is stringified as expected."""
        # (key, input value, expected stored string)
        cases = [
            ("integer", 1, "1.0"),
            ("float", 2.0, "2.0"),
            ("boolean", True, "true"),
            ("null", None, "null"),
            ("string", "letter E", "letter E"),
            (
                "something",
                RuntimeError("You cannot do this!"),
                str(RuntimeError("You cannot do this!")),
            ),
        ]

        stringified = {
            key: GraphStore._coerce_string(raw) for key, raw, _ in cases
        }
        expected = {key: wanted for key, _, wanted in cases}

        assert stringified == expected

0 commit comments

Comments
 (0)