Skip to content

Commit d764289

Browse files
committed
progress supporting metadata search
1 parent 3c1689f commit d764289

File tree

2 files changed

+181
-47
lines changed

2 files changed

+181
-47
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
CONTENT_ID = "content_id"
2929

30+
CONTENT_COLUMNS = "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
31+
32+
SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {table_name} {where_clause} {limit_clause};"
3033

3134
@dataclass
3235
class Node:
@@ -105,8 +108,10 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
105108

106109

107110
def _row_to_node(row: Any) -> Node:
108-
metadata_s = row.get("metadata_s", {})
109-
attributes_blob = row.get("attributes_blob")
111+
metadata_s = row.metadata_s
112+
if metadata_s is None:
113+
metadata_s = {}
114+
attributes_blob = row.attributes_blob
110115
attributes_dict = _deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
111116
links = _deserialize_links(row.links_blob)
112117
return Node(
@@ -164,7 +169,7 @@ def __init__(
164169
self._keyspace = keyspace
165170

166171
self._metadata_indexing_policy = self._normalize_metadata_indexing_policy(
167-
metadata_indexing
172+
metadata_indexing=metadata_indexing,
168173
)
169174

170175
if setup_mode == SetupMode.SYNC:
@@ -187,15 +192,15 @@ def __init__(
187192

188193
self._query_by_id = session.prepare(
189194
f"""
190-
SELECT content_id, kind, text_content, attributes_blob, links_blob
195+
SELECT {CONTENT_COLUMNS}
191196
FROM {keyspace}.{node_table}
192197
WHERE content_id = ?
193198
""" # noqa: S608
194199
)
195200

196201
self._query_by_embedding = session.prepare(
197202
f"""
198-
SELECT content_id, kind, text_content, attributes_blob, links_blob
203+
SELECT {CONTENT_COLUMNS}
199204
FROM {keyspace}.{node_table}
200205
ORDER BY text_embedding ANN OF ?
201206
LIMIT ?
@@ -307,6 +312,25 @@ def _apply_schema(self) -> None:
307312
def _concurrent_queries(self) -> ConcurrentQueries:
308313
return ConcurrentQueries(self._session)
309314

315+
def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str, Dict[str,str]]:
316+
attributes_dict = {
317+
k: self._coerce_string(v)
318+
for k, v in metadata.items()
319+
if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
320+
}
321+
if is_query and len(attributes_dict) > 0:
322+
raise ValueError("Non-indexed metadata fields cannot be used in queries.")
323+
attributes_blob = _serialize_metadata(attributes_dict)
324+
325+
metadata_indexed_dict = {
326+
k: v
327+
for k, v in metadata.items()
328+
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
329+
}
330+
metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
331+
return (attributes_blob, metadata_s)
332+
333+
310334
# TODO: Async (aadd_nodes)
311335
def add_nodes(
312336
self,
@@ -342,19 +366,7 @@ def add_nodes(
342366
if tag.direction in {"out", "bidir"}:
343367
link_to_tags.add((tag.kind, tag.tag))
344368

345-
attributes_dict = {
346-
k: self._coerce_string(v)
347-
for k, v in metadata.items()
348-
if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
349-
}
350-
attributes_blob = _serialize_metadata(attributes_dict)
351-
352-
metadata_indexed_dict = {
353-
k: v
354-
for k, v in metadata.items()
355-
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
356-
}
357-
metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
369+
attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=False)
358370

359371
links_blob = _serialize_links(links)
360372
cq.execute(
@@ -380,7 +392,7 @@ def _nodes_with_ids(
380392
results: Dict[str, Optional[Node]] = {}
381393
with self._concurrent_queries() as cq:
382394

383-
def add_nodes(rows: Iterable[Any]) -> None:
395+
def node_callback(rows: Iterable[Any]) -> None:
384396
# Should always be exactly one row here. We don't need to check
385397
# 1. The query is for a `ID == ?` query on the primary key.
386398
# 2. If it doesn't exist, the `get_result` method below will
@@ -393,7 +405,7 @@ def add_nodes(rows: Iterable[Any]) -> None:
393405
# Mark this node ID as being fetched.
394406
results[node_id] = None
395407
cq.execute(
396-
self._query_by_id, parameters=(node_id,), callback=add_nodes
408+
self._query_by_id, parameters=(node_id,), callback=node_callback
397409
)
398410

399411
def get_result(node_id: str) -> Node:
@@ -643,6 +655,18 @@ def similarity_search(
643655
for row in self._session.execute(self._query_by_embedding, (embedding, k)):
644656
yield _row_to_node(row)
645657

658+
def metadata_search(self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5)-> Iterable[Node]:
659+
query, params = self._get_metadata_search_cql(metadata=metadata, n=n)
660+
661+
prepared_query = self._session.prepare(query)
662+
663+
for row in self._session.execute(prepared_query, params):
664+
yield _row_to_node(row)
665+
666+
def get_node(self, id: str) -> Node:
667+
return self._nodes_with_ids(ids=[id])[0]
668+
669+
646670
def _get_outgoing_tags(
647671
self,
648672
source_ids: Iterable[str],
@@ -755,28 +779,6 @@ def _normalize_metadata_indexing_policy(
755779
)
756780
return (mode, fields)
757781

758-
def _split_metadata_fields(self, md_dict: Dict[str, Any]) -> Dict[str, Any]:
759-
"""
760-
Split the *indexed* part of the metadata in separate parts,
761-
one per Cassandra column.
762-
763-
Currently: everything gets cast to a string and goes to a single table
764-
column. This means:
765-
- strings are fine
766-
- floats and integers v: they are cast to str(v)
767-
- booleans: 'true'/'false' (JSON style)
768-
- None => 'null' (JSON style)
769-
- anything else v => str(v), no questions asked
770-
771-
Caveat: one gets strings back when reading metadata
772-
"""
773-
774-
# TODO: more care about types here
775-
stringy_part = {k: self._coerce_string(v) for k, v in md_dict.items()}
776-
return {
777-
"metadata_s": stringy_part,
778-
}
779-
780782
@staticmethod
781783
def _coerce_string(value: Any) -> str:
782784
if isinstance(value, str):
@@ -794,4 +796,39 @@ def _coerce_string(value: Any) -> str:
794796
return json.dumps(value)
795797
else:
796798
# when all else fails ...
797-
return str(value)
799+
return str(value)
800+
801+
def _extract_where_clause_blocks(
802+
self, metadata: Dict[str, Any]
803+
) -> Tuple[str, List[Any]]:
804+
805+
attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)
806+
807+
if len(metadata_s) == 0:
808+
return "", []
809+
810+
wc_blocks: List[str] = []
811+
vals_list: List[Any] = []
812+
813+
for k, v in sorted(metadata_s.items()):
814+
wc_blocks.append(f"metadata_s['{k}'] = ?")
815+
vals_list.append(v)
816+
817+
where_clause = "WHERE " + " AND ".join(wc_blocks)
818+
return where_clause, vals_list
819+
820+
821+
def _get_metadata_search_cql(self, n: int, metadata: Dict[str, Any]) -> Tuple[str, Tuple[Any, ...]]:
822+
where_clause, get_cql_vals = self._extract_where_clause_blocks(metadata=metadata)
823+
limit_clause = "LIMIT ?"
824+
limit_cql_vals = [n]
825+
select_vals = tuple(list(get_cql_vals) + limit_cql_vals)
826+
#
827+
select_cql = SELECT_CQL_TEMPLATE.format(
828+
columns=CONTENT_COLUMNS,
829+
table_name=f"{self._keyspace}.{self._node_table}",
830+
where_clause=where_clause,
831+
limit_clause=limit_clause,
832+
833+
)
834+
return select_cql, select_vals

libs/knowledge-store/tests/integration_tests/test_graph_store.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import secrets
2-
from typing import Callable, Iterator, List
2+
from typing import Callable, Iterator, List, Optional
33

44
import pytest
55
from dotenv import load_dotenv
66
from ragstack_knowledge_store import EmbeddingModel
7-
from ragstack_knowledge_store.graph_store import GraphStore
7+
from ragstack_knowledge_store.graph_store import GraphStore, Node
88
from ragstack_tests_utils import LocalCassandraTestStore
99

1010
load_dotenv()
@@ -47,7 +47,7 @@ def graph_store_factory(
4747

4848
embedding = DummyEmbeddingModel()
4949

50-
def _make_graph_store() -> GraphStore:
50+
def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore:
5151
name = secrets.token_hex(8)
5252

5353
node_table = f"nodes_{name}"
@@ -56,16 +56,113 @@ def _make_graph_store() -> GraphStore:
5656
session=session,
5757
keyspace=KEYSPACE,
5858
node_table=node_table,
59+
metadata_indexing=metadata_indexing,
5960
)
6061

6162
yield _make_graph_store
6263

6364
session.shutdown()
6465

6566

66-
def test_graph_store_creation(graph_store_factory: Callable[[], GraphStore]) -> None:
67+
def test_graph_store_creation(graph_store_factory: Callable[[str], GraphStore]) -> None:
6768
"""Test that a graph store can be created.
6869
6970
This verifies the schema can be applied and the queries prepared.
7071
"""
7172
graph_store_factory()
73+
74+
def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None:
75+
gs = graph_store_factory()
76+
77+
gs.add_nodes([Node(text="bb1", id="row1")])
78+
gotten1 = gs.get_node(id="row1")
79+
assert gotten1 == Node(text="bb1", id="row1", metadata={})
80+
81+
gs.add_nodes([Node(text=None, id="row2", metadata={})])
82+
gotten2 = gs.get_node(id="row2")
83+
assert gotten2 == Node(text=None, id="row2", metadata={})
84+
85+
md3 = {"a": 1, "b": "Bee", "c": True}
86+
md3_string = {"a": "1.0", "b": "Bee", "c": "true"}
87+
gs.add_nodes([Node(text=None, id="row3", metadata=md3)])
88+
gotten3 = gs.get_node(id="row3")
89+
assert gotten3 == Node(text=None, id="row3", metadata=md3_string)
90+
91+
md4 = {"c1": True, "c2": True, "c3": True}
92+
md4_string = {"c1": "true", "c2": "true", "c3": "true"}
93+
gs.add_nodes([Node(text=None, id="row4", metadata=md4)])
94+
gotten4 = gs.get_node(id="row4")
95+
assert gotten4 == Node(text=None, id="row4", metadata=md4_string)
96+
97+
# metadata searches:
98+
md_gotten3a = list(gs.metadata_search(metadata={"a": 1}))[0]
99+
assert md_gotten3a == gotten3
100+
md_gotten3b = list(gs.metadata_search(metadata={"b": "Bee", "c": True}))[0]
101+
assert md_gotten3b == gotten3
102+
md_gotten4 = list(gs.metadata_search(metadata={"c1": True, "c3": True}))[0]
103+
assert md_gotten4 == gotten4
104+
105+
# 'search' proper
106+
gs.add_nodes([
107+
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
108+
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1})
109+
])
110+
md_twins_gotten = sorted(
111+
list(gs.metadata_search(metadata={"twin": True})),
112+
key=lambda res: int(float(res.metadata["index"]))
113+
)
114+
expected = [
115+
Node(text=None, id="twin_a", metadata={"twin": "true", "index": "0.0"}),
116+
Node(text=None, id="twin_b", metadata={"twin": "true", "index": "1.0"}),
117+
]
118+
assert md_twins_gotten == expected
119+
assert list(gs.metadata_search(metadata={"fake": True})) == []
120+
121+
def test_graph_store_metadata_routing(graph_store_factory: Callable[[str], GraphStore]) -> None:
122+
test_md = {"mds": "string", "mdn": 255, "mdb": True}
123+
test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"}
124+
125+
gs_all = graph_store_factory(metadata_indexing="all")
126+
gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
127+
gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0]
128+
assert gotten_all.metadata == test_md_string
129+
#
130+
gs_none = graph_store_factory(metadata_indexing="none")
131+
gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
132+
with pytest.raises(ValueError):
133+
# querying on non-indexed metadata fields:
134+
list(gs_none.metadata_search(metadata={"mds": "string"}))
135+
gotten_none = gs_none.get_node(id="row1")
136+
assert gotten_none is not None
137+
assert gotten_none.metadata == test_md_string
138+
#
139+
test_md_allowdeny = {
140+
"mdas": "MDAS",
141+
"mdds": "MDDS",
142+
"mdan": 255,
143+
"mddn": 127,
144+
"mdab": True,
145+
"mddb": True,
146+
}
147+
test_md_allowdeny_string = {
148+
"mdas": "MDAS",
149+
"mdds": "MDDS",
150+
"mdan": "255.0",
151+
"mddn": "127.0",
152+
"mdab": "true",
153+
"mddb": "true",
154+
}
155+
#
156+
gs_allow = graph_store_factory(metadata_indexing=("allow", {"mdas", "mdan", "mdab"}))
157+
gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
158+
with pytest.raises(ValueError):
159+
list(gs_allow.metadata_search(metadata={"mdds": "MDDS"}))
160+
gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0]
161+
assert gotten_allow.metadata == test_md_allowdeny_string
162+
#
163+
gs_deny = graph_store_factory(metadata_indexing=("deny", {"mdds", "mddn", "mddb"}))
164+
gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
165+
with pytest.raises(ValueError):
166+
list(gs_deny.metadata_search(metadata={"mdds": "MDDS"}))
167+
gotten_deny = list(gs_deny.metadata_search(metadata={"mdas": "MDAS"}))[0]
168+
assert gotten_deny.metadata == test_md_allowdeny_string

0 commit comments

Comments
 (0)