Skip to content

Commit 5b0bca4

Browse files
committed
restored full metadata_blob
1 parent 47894df commit 5b0bca4

File tree

2 files changed

+46
-80
lines changed

2 files changed

+46
-80
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 35 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,7 @@
2929

3030
CONTENT_ID = "content_id"
3131

32-
CONTENT_COLUMNS = (
33-
"content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
34-
)
32+
CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob"
3533

3634
SELECT_CQL_TEMPLATE = (
3735
"SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
@@ -119,18 +117,12 @@ def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
119117

120118

121119
def _row_to_node(row: Any) -> Node:
122-
metadata_s = row.metadata_s
123-
if metadata_s is None:
124-
metadata_s = {}
125-
attributes_blob = row.attributes_blob
126-
attributes_dict = (
127-
_deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
128-
)
120+
metadata = _deserialize_metadata(row.metadata_blob)
129121
links = _deserialize_links(row.links_blob)
130122
return Node(
131123
id=row.content_id,
132124
text=row.text_content,
133-
metadata={**attributes_dict, **metadata_s},
125+
metadata=metadata,
134126
links=links,
135127
)
136128

@@ -198,7 +190,7 @@ def __init__(
198190
f"""
199191
INSERT INTO {keyspace}.{node_table} (
200192
content_id, kind, text_content, text_embedding, link_to_tags,
201-
link_from_tags, attributes_blob, metadata_s, links_blob
193+
link_from_tags, links_blob, metadata_blob, metadata_s
202194
) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?)
203195
""" # noqa: S608
204196
)
@@ -265,9 +257,9 @@ def _apply_schema(self) -> None:
265257
266258
link_to_tags SET<TUPLE<TEXT, TEXT>>,
267259
link_from_tags SET<TUPLE<TEXT, TEXT>>,
268-
attributes_blob TEXT,
269-
metadata_s MAP<TEXT,TEXT>,
270260
links_blob TEXT,
261+
metadata_blob TEXT,
262+
metadata_s MAP<TEXT,TEXT>,
271263
272264
PRIMARY KEY (content_id)
273265
)
@@ -287,36 +279,14 @@ def _apply_schema(self) -> None:
287279
""")
288280

289281
self._session.execute(f"""
290-
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_index
282+
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index
291283
ON {self.table_name()}(ENTRIES(metadata_s))
292284
USING 'StorageAttachedIndex';
293285
""")
294286

295287
def _concurrent_queries(self) -> ConcurrentQueries:
296288
return ConcurrentQueries(self._session)
297289

298-
def _parse_metadata(
299-
self, metadata: Dict[str, Any], is_query: bool
300-
) -> Tuple[str, Dict[str, str]]:
301-
attributes_dict = {
302-
k: self._coerce_string(v)
303-
for k, v in metadata.items()
304-
if not _is_metadata_field_indexed(k, self._metadata_indexing_policy)
305-
}
306-
if is_query and len(attributes_dict) > 0:
307-
raise ValueError("Non-indexed metadata fields cannot be used in queries.")
308-
attributes_blob = _serialize_metadata(attributes_dict)
309-
310-
metadata_indexed_dict = {
311-
k: v
312-
for k, v in metadata.items()
313-
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
314-
}
315-
metadata_s = {
316-
k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()
317-
}
318-
return (attributes_blob, metadata_s)
319-
320290
# TODO: Async (aadd_nodes)
321291
def add_nodes(
322292
self,
@@ -352,10 +322,13 @@ def add_nodes(
352322
if tag.direction in {"out", "bidir"}:
353323
link_to_tags.add((tag.kind, tag.tag))
354324

355-
attributes_blob, metadata_s = self._parse_metadata(
356-
metadata=metadata, is_query=False
357-
)
325+
metadata_s = {
326+
k: self._coerce_string(v)
327+
for k, v in metadata.items()
328+
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
329+
}
358330

331+
metadata_blob = _serialize_metadata(metadata)
359332
links_blob = _serialize_links(links)
360333
cq.execute(
361334
self._insert_passage,
@@ -365,9 +338,9 @@ def add_nodes(
365338
text_embedding,
366339
link_to_tags,
367340
link_from_tags,
368-
attributes_blob,
369-
metadata_s,
370341
links_blob,
342+
metadata_blob,
343+
metadata_s,
371344
),
372345
)
373346

@@ -413,7 +386,7 @@ def mmr_traversal_search(
413386
adjacent_k: int = 10,
414387
lambda_mult: float = 0.5,
415388
score_threshold: float = float("-inf"),
416-
metadata: Dict[str, Any] = {},
389+
metadata_filter: Dict[str, Any] = {},
417390
) -> Iterable[Node]:
418391
"""Retrieve documents from this graph store using MMR-traversal.
419392
@@ -439,7 +412,7 @@ def mmr_traversal_search(
439412
diversity and 1 to minimum diversity. Defaults to 0.5.
440413
score_threshold: Only documents with a score greater than or equal
441414
this threshold will be chosen. Defaults to -infinity.
442-
metadata: Optional metadata to filter the results.
415+
metadata_filter: Optional metadata to filter the results.
443416
"""
444417
query_embedding = self._embedding.embed_query(query)
445418
helper = MmrHelper(
@@ -458,7 +431,7 @@ def fetch_initial_candidates() -> None:
458431
query, params = self._get_search_cql(
459432
limit=fetch_k,
460433
columns="content_id, text_embedding, link_to_tags",
461-
metadata=metadata,
434+
metadata=metadata_filter,
462435
embedding=query_embedding,
463436
)
464437

@@ -539,7 +512,7 @@ def traversal_search(
539512
*,
540513
k: int = 4,
541514
depth: int = 1,
542-
metadata: Dict[str, Any] = {},
515+
metadata_filter: Dict[str, Any] = {},
543516
) -> Iterable[Node]:
544517
"""Retrieve documents from this knowledge store.
545518
@@ -552,7 +525,7 @@ def traversal_search(
552525
k: The number of Documents to return from the initial vector search.
553526
Defaults to 4.
554527
depth: The maximum depth of edges to traverse. Defaults to 1.
555-
metadata: Optional metadata to filter the results.
528+
metadata_filter: Optional metadata to filter the results.
556529
557530
Returns:
558531
Collection of retrieved documents.
@@ -639,8 +612,9 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None:
639612

640613
query_embedding = self._embedding.embed_query(query)
641614
query, params = self._get_search_cql(
615+
columns="content_id, link_to_tags",
642616
limit=k,
643-
metadata=metadata,
617+
metadata=metadata_filter,
644618
embedding=query_embedding,
645619
)
646620

@@ -656,11 +630,11 @@ def similarity_search(
656630
self,
657631
embedding: List[float],
658632
k: int = 4,
659-
metadata: Dict[str, Any] = {},
633+
metadata_filter: Dict[str, Any] = {},
660634
) -> Iterable[Node]:
661635
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
662636
query, params = self._get_search_cql(
663-
embedding=embedding, limit=k, metadata=metadata
637+
embedding=embedding, limit=k, metadata=metadata_filter
664638
)
665639

666640
for row in self._session.execute(query, params):
@@ -813,17 +787,20 @@ def _coerce_string(value: Any) -> str:
813787
def _extract_where_clause_blocks(
814788
self, metadata: Dict[str, Any]
815789
) -> Tuple[str, List[Any]]:
816-
_, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)
817-
818-
if len(metadata_s) == 0:
819-
return "", []
820-
821790
wc_blocks: List[str] = []
822791
vals_list: List[Any] = []
823792

824-
for k, v in sorted(metadata_s.items()):
825-
wc_blocks.append(f"metadata_s['{k}'] = ?")
826-
vals_list.append(v)
793+
for key, value in sorted(metadata.items()):
794+
if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
795+
wc_blocks.append(f"metadata_s['{key}'] = ?")
796+
vals_list.append(self._coerce_string(value=value))
797+
else:
798+
raise ValueError(
799+
"Non-indexed metadata fields cannot be used in queries."
800+
)
801+
802+
if len(wc_blocks) == 0:
803+
return "", []
827804

828805
where_clause = "WHERE " + " AND ".join(wc_blocks)
829806
return where_clause, vals_list

libs/knowledge-store/tests/integration_tests/test_graph_store.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ruff: noqa: PT011, RUF015
22
import math
33
import secrets
4-
from typing import Callable, Iterable, Iterator, List, Optional
4+
from typing import Callable, Iterable, Iterator, List
55

66
import numpy as np
77
import pytest
@@ -272,8 +272,8 @@ def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
272272
Link(direction="in", kind="hyperlink", tag="http://a"),
273273
Link(direction="bidir", kind="other", tag="foo"),
274274
}
275-
def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None:
276-
gs = graph_store_factory()
275+
276+
277277
def test_graph_store_metadata(
278278
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
279279
) -> None:
@@ -288,16 +288,14 @@ def test_graph_store_metadata(
288288
assert gotten2 == Node(text="", id="row2", metadata={})
289289

290290
md3 = {"a": 1, "b": "Bee", "c": True}
291-
md3_string = {"a": "1.0", "b": "Bee", "c": "true"}
292291
gs.add_nodes([Node(text="", id="row3", metadata=md3)])
293292
gotten3 = gs.get_node(content_id="row3")
294-
assert gotten3 == Node(text="", id="row3", metadata=md3_string)
293+
assert gotten3 == Node(text="", id="row3", metadata=md3)
295294

296295
md4 = {"c1": True, "c2": True, "c3": True}
297-
md4_string = {"c1": "true", "c2": "true", "c3": "true"}
298296
gs.add_nodes([Node(text="", id="row4", metadata=md4)])
299297
gotten4 = gs.get_node(content_id="row4")
300-
assert gotten4 == Node(text="", id="row4", metadata=md4_string)
298+
assert gotten4 == Node(text="", id="row4", metadata=md4)
301299

302300
# metadata searches:
303301
md_gotten3a = list(gs.metadata_search(metadata={"a": 1}))[0]
@@ -319,8 +317,8 @@ def test_graph_store_metadata(
319317
key=lambda res: int(float(res.metadata["index"])),
320318
)
321319
expected = [
322-
Node(text="", id="twin_a", metadata={"twin": "true", "index": "0.0"}),
323-
Node(text="", id="twin_b", metadata={"twin": "true", "index": "1.0"}),
320+
Node(text="", id="twin_a", metadata={"twin": True, "index": 0}),
321+
Node(text="", id="twin_b", metadata={"twin": True, "index": 1}),
324322
]
325323
assert md_twins_gotten == expected
326324
assert list(gs.metadata_search(metadata={"fake": True})) == []
@@ -330,20 +328,19 @@ def test_graph_store_metadata_routing(
330328
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
331329
) -> None:
332330
test_md = {"mds": "string", "mdn": 255, "mdb": True}
333-
test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"}
334331

335332
gs_all = graph_store_factory("all")
336333
gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
337334
gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0]
338-
assert gotten_all.metadata == test_md_string
335+
assert gotten_all.metadata == test_md
339336
gs_none = graph_store_factory("none")
340337
gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
341338
with pytest.raises(ValueError):
342339
# querying on non-indexed metadata fields:
343340
list(gs_none.metadata_search(metadata={"mds": "string"}))
344341
gotten_none = gs_none.get_node(content_id="row1")
345342
assert gotten_none is not None
346-
assert gotten_none.metadata == test_md_string
343+
assert gotten_none.metadata == test_md
347344
test_md_allowdeny = {
348345
"mdas": "MDAS",
349346
"mdds": "MDDS",
@@ -352,23 +349,15 @@ def test_graph_store_metadata_routing(
352349
"mdab": True,
353350
"mddb": True,
354351
}
355-
test_md_allowdeny_string = {
356-
"mdas": "MDAS",
357-
"mdds": "MDDS",
358-
"mdan": "255.0",
359-
"mddn": "127.0",
360-
"mdab": "true",
361-
"mddb": "true",
362-
}
363352
gs_allow = graph_store_factory(("allow", {"mdas", "mdan", "mdab"}))
364353
gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
365354
with pytest.raises(ValueError):
366355
list(gs_allow.metadata_search(metadata={"mdds": "MDDS"}))
367356
gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0]
368-
assert gotten_allow.metadata == test_md_allowdeny_string
357+
assert gotten_allow.metadata == test_md_allowdeny
369358
gs_deny = graph_store_factory(("deny", {"mdds", "mddn", "mddb"}))
370359
gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
371360
with pytest.raises(ValueError):
372361
list(gs_deny.metadata_search(metadata={"mdds": "MDDS"}))
373362
gotten_deny = list(gs_deny.metadata_search(metadata={"mdas": "MDAS"}))[0]
374-
assert gotten_deny.metadata == test_md_allowdeny_string
363+
assert gotten_deny.metadata == test_md_allowdeny

0 commit comments

Comments
 (0)