Skip to content

Commit ff61d05

Browse files
committed
fix type check
1 parent e7ef77f commit ff61d05

File tree

2 files changed

+33
-30
lines changed

2 files changed

+33
-30
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class MetadataIndexingMode(Enum):
6767
DEFAULT_TO_SEARCHABLE = 2
6868

6969

70+
MetadataIndexingType = Union[Tuple[str, Iterable[str]], str]
7071
MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]
7172

7273

@@ -166,7 +167,7 @@ def __init__(
166167
session: Optional[Session] = None,
167168
keyspace: Optional[str] = None,
168169
setup_mode: SetupMode = SetupMode.SYNC,
169-
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
170+
metadata_indexing: MetadataIndexingType = "all",
170171
):
171172
session = check_resolve_session(session)
172173
keyspace = check_resolve_keyspace(keyspace)
@@ -414,7 +415,7 @@ def mmr_traversal_search(
414415
adjacent_k: int = 10,
415416
lambda_mult: float = 0.5,
416417
score_threshold: float = float("-inf"),
417-
metadata: Optional[Dict[str, Any]] = [],
418+
metadata: Dict[str, Any] = {},
418419
) -> Iterable[Node]:
419420
"""Retrieve documents from this graph store using MMR-traversal.
420421
@@ -540,7 +541,7 @@ def traversal_search(
540541
*,
541542
k: int = 4,
542543
depth: int = 1,
543-
metadata: Optional[Dict[str, Any]] = [],
544+
metadata: Dict[str, Any] = {},
544545
) -> Iterable[Node]:
545546
"""Retrieve documents from this knowledge store.
546547
@@ -657,7 +658,7 @@ def similarity_search(
657658
self,
658659
embedding: List[float],
659660
k: int = 4,
660-
metadata: Optional[Dict[str, Any]] = [],
661+
metadata: Dict[str, Any] = {},
661662
) -> Iterable[Node]:
662663
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
663664
query, params = self._get_search_cql(
@@ -668,7 +669,7 @@ def similarity_search(
668669
yield _row_to_node(row)
669670

670671
def metadata_search(
671-
self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5
672+
self, metadata: Dict[str, Any] = {}, n: int = 5
672673
) -> Iterable[Node]:
673674
"""Retrieve nodes based on their metadata."""
674675
query, params = self._get_search_cql(metadata=metadata, limit=n)
@@ -833,7 +834,7 @@ def _get_search_cql(
833834
self,
834835
limit: int,
835836
columns: Optional[str] = CONTENT_COLUMNS,
836-
metadata: Optional[Dict[str, Any]] = {},
837+
metadata: Dict[str, Any] = {},
837838
embedding: Optional[List[float]] = None,
838839
) -> Tuple[str, Tuple[Any, ...]]:
839840
where_clause, get_cql_vals = self._extract_where_clause_blocks(

libs/knowledge-store/tests/integration_tests/test_graph_store.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# ruff: noqa: PT011, RUF015
22

33
import secrets
4-
from typing import Callable, Iterator, List, Optional
4+
from typing import Callable, Iterator, List
55

66
import pytest
77
from dotenv import load_dotenv
88
from ragstack_knowledge_store import EmbeddingModel
9-
from ragstack_knowledge_store.graph_store import GraphStore, Node
9+
from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingType, Node
1010
from ragstack_tests_utils import LocalCassandraTestStore
1111

1212
load_dotenv()
@@ -49,7 +49,7 @@ def graph_store_factory(
4949

5050
embedding = DummyEmbeddingModel()
5151

52-
def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore:
52+
def _make_graph_store(metadata_indexing: str = "all") -> GraphStore:
5353
name = secrets.token_hex(8)
5454

5555
node_table = f"nodes_{name}"
@@ -66,36 +66,40 @@ def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore:
6666
session.shutdown()
6767

6868

69-
def test_graph_store_creation(graph_store_factory: Callable[[str], GraphStore]) -> None:
69+
def test_graph_store_creation(
70+
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
71+
) -> None:
7072
"""Test that a graph store can be created.
7173
7274
This verifies the schema can be applied and the queries prepared.
7375
"""
74-
graph_store_factory()
76+
graph_store_factory("all")
7577

7678

77-
def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None:
78-
gs = graph_store_factory()
79+
def test_graph_store_metadata(
80+
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
81+
) -> None:
82+
gs = graph_store_factory("all")
7983

8084
gs.add_nodes([Node(text="bb1", id="row1")])
8185
gotten1 = gs.get_node(content_id="row1")
8286
assert gotten1 == Node(text="bb1", id="row1", metadata={})
8387

84-
gs.add_nodes([Node(text=None, id="row2", metadata={})])
88+
gs.add_nodes([Node(text="", id="row2", metadata={})])
8589
gotten2 = gs.get_node(content_id="row2")
86-
assert gotten2 == Node(text=None, id="row2", metadata={})
90+
assert gotten2 == Node(text="", id="row2", metadata={})
8791

8892
md3 = {"a": 1, "b": "Bee", "c": True}
8993
md3_string = {"a": "1.0", "b": "Bee", "c": "true"}
90-
gs.add_nodes([Node(text=None, id="row3", metadata=md3)])
94+
gs.add_nodes([Node(text="", id="row3", metadata=md3)])
9195
gotten3 = gs.get_node(content_id="row3")
92-
assert gotten3 == Node(text=None, id="row3", metadata=md3_string)
96+
assert gotten3 == Node(text="", id="row3", metadata=md3_string)
9397

9498
md4 = {"c1": True, "c2": True, "c3": True}
9599
md4_string = {"c1": "true", "c2": "true", "c3": "true"}
96-
gs.add_nodes([Node(text=None, id="row4", metadata=md4)])
100+
gs.add_nodes([Node(text="", id="row4", metadata=md4)])
97101
gotten4 = gs.get_node(content_id="row4")
98-
assert gotten4 == Node(text=None, id="row4", metadata=md4_string)
102+
assert gotten4 == Node(text="", id="row4", metadata=md4_string)
99103

100104
# metadata searches:
101105
md_gotten3a = list(gs.metadata_search(metadata={"a": 1}))[0]
@@ -108,33 +112,33 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
108112
# 'search' proper
109113
gs.add_nodes(
110114
[
111-
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
112-
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1}),
115+
Node(text="", id="twin_a", metadata={"twin": True, "index": 0}),
116+
Node(text="", id="twin_b", metadata={"twin": True, "index": 1}),
113117
]
114118
)
115119
md_twins_gotten = sorted(
116120
gs.metadata_search(metadata={"twin": True}),
117121
key=lambda res: int(float(res.metadata["index"])),
118122
)
119123
expected = [
120-
Node(text=None, id="twin_a", metadata={"twin": "true", "index": "0.0"}),
121-
Node(text=None, id="twin_b", metadata={"twin": "true", "index": "1.0"}),
124+
Node(text="", id="twin_a", metadata={"twin": "true", "index": "0.0"}),
125+
Node(text="", id="twin_b", metadata={"twin": "true", "index": "1.0"}),
122126
]
123127
assert md_twins_gotten == expected
124128
assert list(gs.metadata_search(metadata={"fake": True})) == []
125129

126130

127131
def test_graph_store_metadata_routing(
128-
graph_store_factory: Callable[[str], GraphStore],
132+
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
129133
) -> None:
130134
test_md = {"mds": "string", "mdn": 255, "mdb": True}
131135
test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"}
132136

133-
gs_all = graph_store_factory(metadata_indexing="all")
137+
gs_all = graph_store_factory("all")
134138
gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
135139
gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0]
136140
assert gotten_all.metadata == test_md_string
137-
gs_none = graph_store_factory(metadata_indexing="none")
141+
gs_none = graph_store_factory("none")
138142
gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
139143
with pytest.raises(ValueError):
140144
# querying on non-indexed metadata fields:
@@ -158,15 +162,13 @@ def test_graph_store_metadata_routing(
158162
"mdab": "true",
159163
"mddb": "true",
160164
}
161-
gs_allow = graph_store_factory(
162-
metadata_indexing=("allow", {"mdas", "mdan", "mdab"})
163-
)
165+
gs_allow = graph_store_factory(("allow", {"mdas", "mdan", "mdab"}))
164166
gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
165167
with pytest.raises(ValueError):
166168
list(gs_allow.metadata_search(metadata={"mdds": "MDDS"}))
167169
gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0]
168170
assert gotten_allow.metadata == test_md_allowdeny_string
169-
gs_deny = graph_store_factory(metadata_indexing=("deny", {"mdds", "mddn", "mddb"}))
171+
gs_deny = graph_store_factory(("deny", {"mdds", "mddn", "mddb"}))
170172
gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
171173
with pytest.raises(ValueError):
172174
list(gs_deny.metadata_search(metadata={"mdds": "MDDS"}))

0 commit comments

Comments
 (0)