Skip to content

Commit c5eafb9

Browse files
committed
fmt
1 parent 3a66048 commit c5eafb9

File tree

4 files changed

+104
-65
lines changed

4 files changed

+104
-65
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 70 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
# ruff: noqa: B006
2+
13
import json
24
import re
35
import secrets
@@ -16,7 +18,7 @@
1618
cast,
1719
)
1820

19-
from cassandra.cluster import ConsistencyLevel, Session, ResponseFuture
21+
from cassandra.cluster import ConsistencyLevel, Session
2022
from cassio.config import check_resolve_keyspace, check_resolve_session
2123

2224
from ._mmr_helper import MmrHelper
@@ -27,9 +29,14 @@
2729

2830
CONTENT_ID = "content_id"
2931

30-
CONTENT_COLUMNS = "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
32+
CONTENT_COLUMNS = (
33+
"content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
34+
)
35+
36+
SELECT_CQL_TEMPLATE = (
37+
"SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
38+
)
3139

32-
SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
3340

3441
@dataclass
3542
class Node:
@@ -52,20 +59,25 @@ class SetupMode(Enum):
5259
ASYNC = 2
5360
OFF = 3
5461

62+
5563
class MetadataIndexingMode(Enum):
64+
"""Mode used to index metadata."""
65+
5666
DEFAULT_TO_UNSEARCHABLE = 1
5767
DEFAULT_TO_SEARCHABLE = 2
5868

69+
5970
MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]
6071

72+
6173
def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool:
6274
p_mode, p_fields = policy
6375
if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE:
6476
return field_name in p_fields
65-
elif p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
77+
if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
6678
return field_name not in p_fields
67-
else:
68-
raise ValueError(f"Unexpected metadata indexing mode {p_mode}")
79+
raise ValueError(f"Unexpected metadata indexing mode {p_mode}")
80+
6981

7082
def _serialize_metadata(md: Dict[str, Any]) -> str:
7183
if isinstance(md.get("links"), Set):
@@ -112,7 +124,9 @@ def _row_to_node(row: Any) -> Node:
112124
if metadata_s is None:
113125
metadata_s = {}
114126
attributes_blob = row.attributes_blob
115-
attributes_dict = _deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
127+
attributes_dict = (
128+
_deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
129+
)
116130
links = _deserialize_links(row.links_blob)
117131
return Node(
118132
id=row.content_id,
@@ -237,6 +251,7 @@ def __init__(
237251
)
238252

239253
def table_name(self) -> str:
254+
"""Returns the fully qualified table name."""
240255
return f"{self._keyspace}.{self._node_table}"
241256

242257
def _apply_schema(self) -> None:
@@ -281,7 +296,9 @@ def _apply_schema(self) -> None:
281296
def _concurrent_queries(self) -> ConcurrentQueries:
282297
return ConcurrentQueries(self._session)
283298

284-
def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str, Dict[str,str]]:
299+
def _parse_metadata(
300+
self, metadata: Dict[str, Any], is_query: bool
301+
) -> Tuple[str, Dict[str, str]]:
285302
attributes_dict = {
286303
k: self._coerce_string(v)
287304
for k, v in metadata.items()
@@ -296,10 +313,11 @@ def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str
296313
for k, v in metadata.items()
297314
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
298315
}
299-
metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
316+
metadata_s = {
317+
k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()
318+
}
300319
return (attributes_blob, metadata_s)
301320

302-
303321
# TODO: Async (aadd_nodes)
304322
def add_nodes(
305323
self,
@@ -335,7 +353,9 @@ def add_nodes(
335353
if tag.direction in {"out", "bidir"}:
336354
link_to_tags.add((tag.kind, tag.tag))
337355

338-
attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=False)
356+
attributes_blob, metadata_s = self._parse_metadata(
357+
metadata=metadata, is_query=False
358+
)
339359

340360
links_blob = _serialize_links(links)
341361
cq.execute(
@@ -440,7 +460,7 @@ def fetch_initial_candidates() -> None:
440460
limit=fetch_k,
441461
columns="content_id, text_embedding, link_to_tags",
442462
metadata=metadata,
443-
embedding=query_embedding
463+
embedding=query_embedding,
444464
)
445465

446466
fetched = self._session.execute(query=query, parameters=params)
@@ -515,7 +535,12 @@ def fetch_initial_candidates() -> None:
515535
return self._nodes_with_ids(helper.selected_ids)
516536

517537
def traversal_search(
518-
self, query: str, *, k: int = 4, depth: int = 1, metadata: Optional[Dict[str, Any]] = [],
538+
self,
539+
query: str,
540+
*,
541+
k: int = 4,
542+
depth: int = 1,
543+
metadata: Optional[Dict[str, Any]] = [],
519544
) -> Iterable[Node]:
520545
"""Retrieve documents from this knowledge store.
521546
@@ -634,21 +659,26 @@ def similarity_search(
634659
k: int = 4,
635660
metadata: Optional[Dict[str, Any]] = [],
636661
) -> Iterable[Node]:
637-
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata"""
638-
query, params = self._get_search_cql(embedding=embedding, limit=k, metadata=metadata)
662+
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
663+
query, params = self._get_search_cql(
664+
embedding=embedding, limit=k, metadata=metadata
665+
)
639666

640667
for row in self._session.execute(query, params):
641668
yield _row_to_node(row)
642669

643-
def metadata_search(self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5)-> Iterable[Node]:
670+
def metadata_search(
671+
self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5
672+
) -> Iterable[Node]:
673+
"""Retrieve nodes based on their metadata."""
644674
query, params = self._get_search_cql(metadata=metadata, limit=n)
645675

646676
for row in self._session.execute(query, params):
647677
yield _row_to_node(row)
648678

649-
def get_node(self, id: str) -> Node:
650-
return self._nodes_with_ids(ids=[id])[0]
651-
679+
def get_node(self, content_id: str) -> Node:
680+
"""Get a node by its id."""
681+
return self._nodes_with_ids(ids=[content_id])[0]
652682

653683
def _get_outgoing_tags(
654684
self,
@@ -723,7 +753,7 @@ def add_targets(rows: Iterable[Any]) -> None:
723753

724754
@staticmethod
725755
def _normalize_metadata_indexing_policy(
726-
metadata_indexing: Union[Tuple[str, Iterable[str]], str]
756+
metadata_indexing: Union[Tuple[str, Iterable[str]], str],
727757
) -> MetadataIndexingPolicy:
728758
mode: MetadataIndexingMode
729759
fields: Set[str]
@@ -738,7 +768,10 @@ def _normalize_metadata_indexing_policy(
738768
f"Unsupported metadata_indexing value '{metadata_indexing}'"
739769
)
740770
else:
741-
assert len(metadata_indexing) == 2
771+
if len(metadata_indexing) != 2: # noqa: PLR2004
772+
raise ValueError(
773+
f"Unsupported metadata_indexing value '{metadata_indexing}'."
774+
)
742775
# it's a 2-tuple (mode, fields) still to normalize
743776
_mode, _field_spec = metadata_indexing
744777
fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec)
@@ -766,25 +799,21 @@ def _normalize_metadata_indexing_policy(
766799
def _coerce_string(value: Any) -> str:
767800
if isinstance(value, str):
768801
return value
769-
elif isinstance(value, bool):
802+
if isinstance(value, bool):
770803
# bool MUST come before int in this chain of ifs!
771804
return json.dumps(value)
772-
elif isinstance(value, int):
805+
if isinstance(value, int):
773806
# we don't want to store '1' and '1.0' differently
774807
# for the sake of metadata-filtered retrieval:
775808
return json.dumps(float(value))
776-
elif isinstance(value, float):
809+
if isinstance(value, float) or value is None:
777810
return json.dumps(value)
778-
elif value is None:
779-
return json.dumps(value)
780-
else:
781-
# when all else fails ...
782-
return str(value)
811+
# when all else fails ...
812+
return str(value)
783813

784814
def _extract_where_clause_blocks(
785815
self, metadata: Dict[str, Any]
786816
) -> Tuple[str, List[Any]]:
787-
788817
_, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)
789818

790819
if len(metadata_s) == 0:
@@ -800,13 +829,20 @@ def _extract_where_clause_blocks(
800829
where_clause = "WHERE " + " AND ".join(wc_blocks)
801830
return where_clause, vals_list
802831

803-
804-
def _get_search_cql(self, limit: int, columns: Optional[str] = CONTENT_COLUMNS, metadata: Optional[Dict[str, Any]] = {}, embedding: Optional[List[float]] = None) -> Tuple[str, Tuple[Any, ...]]:
805-
where_clause, get_cql_vals = self._extract_where_clause_blocks(metadata=metadata)
832+
def _get_search_cql(
833+
self,
834+
limit: int,
835+
columns: Optional[str] = CONTENT_COLUMNS,
836+
metadata: Optional[Dict[str, Any]] = {},
837+
embedding: Optional[List[float]] = None,
838+
) -> Tuple[str, Tuple[Any, ...]]:
839+
where_clause, get_cql_vals = self._extract_where_clause_blocks(
840+
metadata=metadata
841+
)
806842
limit_clause = "LIMIT ?"
807843
limit_cql_vals = [limit]
808844

809-
order_clause=""
845+
order_clause = ""
810846
order_cql_vals = []
811847
if embedding is not None:
812848
order_clause = "ORDER BY text_embedding ANN OF ?"
@@ -819,7 +855,6 @@ def _get_search_cql(self, limit: int, columns: Optional[str] = CONTENT_COLUMNS,
819855
where_clause=where_clause,
820856
order_clause=order_clause,
821857
limit_clause=limit_clause,
822-
823858
)
824859
prepared_query = self._session.prepare(select_cql)
825860
prepared_query.consistency_level = ConsistencyLevel.ONE

libs/knowledge-store/tests/integration_tests/test_graph_store.py

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
# ruff: noqa: PT011, RUF015
2+
13
import secrets
24
from typing import Callable, Iterator, List, Optional
35

@@ -71,27 +73,28 @@ def test_graph_store_creation(graph_store_factory: Callable[[str], GraphStore])
7173
"""
7274
graph_store_factory()
7375

76+
7477
def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None:
7578
gs = graph_store_factory()
7679

7780
gs.add_nodes([Node(text="bb1", id="row1")])
78-
gotten1 = gs.get_node(id="row1")
81+
gotten1 = gs.get_node(content_id="row1")
7982
assert gotten1 == Node(text="bb1", id="row1", metadata={})
8083

8184
gs.add_nodes([Node(text=None, id="row2", metadata={})])
82-
gotten2 = gs.get_node(id="row2")
85+
gotten2 = gs.get_node(content_id="row2")
8386
assert gotten2 == Node(text=None, id="row2", metadata={})
8487

8588
md3 = {"a": 1, "b": "Bee", "c": True}
8689
md3_string = {"a": "1.0", "b": "Bee", "c": "true"}
8790
gs.add_nodes([Node(text=None, id="row3", metadata=md3)])
88-
gotten3 = gs.get_node(id="row3")
91+
gotten3 = gs.get_node(content_id="row3")
8992
assert gotten3 == Node(text=None, id="row3", metadata=md3_string)
9093

9194
md4 = {"c1": True, "c2": True, "c3": True}
9295
md4_string = {"c1": "true", "c2": "true", "c3": "true"}
9396
gs.add_nodes([Node(text=None, id="row4", metadata=md4)])
94-
gotten4 = gs.get_node(id="row4")
97+
gotten4 = gs.get_node(content_id="row4")
9598
assert gotten4 == Node(text=None, id="row4", metadata=md4_string)
9699

97100
# metadata searches:
@@ -103,13 +106,15 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
103106
assert md_gotten4 == gotten4
104107

105108
# 'search' proper
106-
gs.add_nodes([
107-
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
108-
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1})
109-
])
109+
gs.add_nodes(
110+
[
111+
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
112+
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1}),
113+
]
114+
)
110115
md_twins_gotten = sorted(
111-
list(gs.metadata_search(metadata={"twin": True})),
112-
key=lambda res: int(float(res.metadata["index"]))
116+
gs.metadata_search(metadata={"twin": True}),
117+
key=lambda res: int(float(res.metadata["index"])),
113118
)
114119
expected = [
115120
Node(text=None, id="twin_a", metadata={"twin": "true", "index": "0.0"}),
@@ -118,24 +123,25 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
118123
assert md_twins_gotten == expected
119124
assert list(gs.metadata_search(metadata={"fake": True})) == []
120125

121-
def test_graph_store_metadata_routing(graph_store_factory: Callable[[str], GraphStore]) -> None:
126+
127+
def test_graph_store_metadata_routing(
128+
graph_store_factory: Callable[[str], GraphStore],
129+
) -> None:
122130
test_md = {"mds": "string", "mdn": 255, "mdb": True}
123131
test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"}
124132

125133
gs_all = graph_store_factory(metadata_indexing="all")
126134
gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
127135
gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0]
128136
assert gotten_all.metadata == test_md_string
129-
#
130137
gs_none = graph_store_factory(metadata_indexing="none")
131138
gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
132139
with pytest.raises(ValueError):
133140
# querying on non-indexed metadata fields:
134141
list(gs_none.metadata_search(metadata={"mds": "string"}))
135-
gotten_none = gs_none.get_node(id="row1")
142+
gotten_none = gs_none.get_node(content_id="row1")
136143
assert gotten_none is not None
137144
assert gotten_none.metadata == test_md_string
138-
#
139145
test_md_allowdeny = {
140146
"mdas": "MDAS",
141147
"mdds": "MDDS",
@@ -152,14 +158,14 @@ def test_graph_store_metadata_routing(graph_store_factory: Callable[[str], Graph
152158
"mdab": "true",
153159
"mddb": "true",
154160
}
155-
#
156-
gs_allow = graph_store_factory(metadata_indexing=("allow", {"mdas", "mdan", "mdab"}))
161+
gs_allow = graph_store_factory(
162+
metadata_indexing=("allow", {"mdas", "mdan", "mdab"})
163+
)
157164
gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
158165
with pytest.raises(ValueError):
159166
list(gs_allow.metadata_search(metadata={"mdds": "MDDS"}))
160167
gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0]
161168
assert gotten_allow.metadata == test_md_allowdeny_string
162-
#
163169
gs_deny = graph_store_factory(metadata_indexing=("deny", {"mdds", "mddn", "mddb"}))
164170
gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
165171
with pytest.raises(ValueError):

libs/knowledge-store/tests/unit_tests/test_metadata_policy_normalization.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,21 @@
1+
# ruff: noqa: SLF001
12
"""
23
Normalization of metadata policy specification options
34
"""
45

5-
from ragstack_knowledge_store.graph_store import MetadataIndexingMode, GraphStore
6+
from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingMode
67

78

89
class TestNormalizeMetadataPolicy:
910
def test_normalize_metadata_policy(self) -> None:
10-
#
1111
mdp1 = GraphStore._normalize_metadata_indexing_policy("all")
1212
assert mdp1 == (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set())
13-
#
1413
mdp2 = GraphStore._normalize_metadata_indexing_policy("none")
1514
assert mdp2 == (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set())
16-
#
1715
mdp3 = GraphStore._normalize_metadata_indexing_policy(
1816
("default_to_Unsearchable", ["x", "y"]),
1917
)
2018
assert mdp3 == (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, {"x", "y"})
21-
#
2219
mdp4 = GraphStore._normalize_metadata_indexing_policy(
2320
("DenyList", ["z"]),
2421
)

0 commit comments

Comments
 (0)