1
1
# ruff: noqa: PT011, RUF015
2
2
3
3
import secrets
4
- from typing import Callable , Iterator , List , Optional
4
+ from typing import Callable , Iterator , List
5
5
6
6
import pytest
7
7
from dotenv import load_dotenv
8
8
from ragstack_knowledge_store import EmbeddingModel
9
- from ragstack_knowledge_store .graph_store import GraphStore , Node
9
+ from ragstack_knowledge_store .graph_store import GraphStore , MetadataIndexingType , Node
10
10
from ragstack_tests_utils import LocalCassandraTestStore
11
11
12
12
load_dotenv ()
@@ -49,7 +49,7 @@ def graph_store_factory(
49
49
50
50
embedding = DummyEmbeddingModel ()
51
51
52
- def _make_graph_store (metadata_indexing : Optional [ str ] = "all" ) -> GraphStore :
52
+ def _make_graph_store (metadata_indexing : str = "all" ) -> GraphStore :
53
53
name = secrets .token_hex (8 )
54
54
55
55
node_table = f"nodes_{ name } "
@@ -66,36 +66,40 @@ def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore:
66
66
session .shutdown ()
67
67
68
68
69
- def test_graph_store_creation (graph_store_factory : Callable [[str ], GraphStore ]) -> None :
69
+ def test_graph_store_creation (
70
+ graph_store_factory : Callable [[MetadataIndexingType ], GraphStore ],
71
+ ) -> None :
70
72
"""Test that a graph store can be created.
71
73
72
74
This verifies the schema can be applied and the queries prepared.
73
75
"""
74
- graph_store_factory ()
76
+ graph_store_factory ("all" )
75
77
76
78
77
- def test_graph_store_metadata (graph_store_factory : Callable [[str ], GraphStore ]) -> None :
78
- gs = graph_store_factory ()
79
+ def test_graph_store_metadata (
80
+ graph_store_factory : Callable [[MetadataIndexingType ], GraphStore ],
81
+ ) -> None :
82
+ gs = graph_store_factory ("all" )
79
83
80
84
gs .add_nodes ([Node (text = "bb1" , id = "row1" )])
81
85
gotten1 = gs .get_node (content_id = "row1" )
82
86
assert gotten1 == Node (text = "bb1" , id = "row1" , metadata = {})
83
87
84
- gs .add_nodes ([Node (text = None , id = "row2" , metadata = {})])
88
+ gs .add_nodes ([Node (text = "" , id = "row2" , metadata = {})])
85
89
gotten2 = gs .get_node (content_id = "row2" )
86
- assert gotten2 == Node (text = None , id = "row2" , metadata = {})
90
+ assert gotten2 == Node (text = "" , id = "row2" , metadata = {})
87
91
88
92
md3 = {"a" : 1 , "b" : "Bee" , "c" : True }
89
93
md3_string = {"a" : "1.0" , "b" : "Bee" , "c" : "true" }
90
- gs .add_nodes ([Node (text = None , id = "row3" , metadata = md3 )])
94
+ gs .add_nodes ([Node (text = "" , id = "row3" , metadata = md3 )])
91
95
gotten3 = gs .get_node (content_id = "row3" )
92
- assert gotten3 == Node (text = None , id = "row3" , metadata = md3_string )
96
+ assert gotten3 == Node (text = "" , id = "row3" , metadata = md3_string )
93
97
94
98
md4 = {"c1" : True , "c2" : True , "c3" : True }
95
99
md4_string = {"c1" : "true" , "c2" : "true" , "c3" : "true" }
96
- gs .add_nodes ([Node (text = None , id = "row4" , metadata = md4 )])
100
+ gs .add_nodes ([Node (text = "" , id = "row4" , metadata = md4 )])
97
101
gotten4 = gs .get_node (content_id = "row4" )
98
- assert gotten4 == Node (text = None , id = "row4" , metadata = md4_string )
102
+ assert gotten4 == Node (text = "" , id = "row4" , metadata = md4_string )
99
103
100
104
# metadata searches:
101
105
md_gotten3a = list (gs .metadata_search (metadata = {"a" : 1 }))[0 ]
@@ -108,33 +112,33 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
108
112
# 'search' proper
109
113
gs .add_nodes (
110
114
[
111
- Node (text = None , id = "twin_a" , metadata = {"twin" : True , "index" : 0 }),
112
- Node (text = None , id = "twin_b" , metadata = {"twin" : True , "index" : 1 }),
115
+ Node (text = "" , id = "twin_a" , metadata = {"twin" : True , "index" : 0 }),
116
+ Node (text = "" , id = "twin_b" , metadata = {"twin" : True , "index" : 1 }),
113
117
]
114
118
)
115
119
md_twins_gotten = sorted (
116
120
gs .metadata_search (metadata = {"twin" : True }),
117
121
key = lambda res : int (float (res .metadata ["index" ])),
118
122
)
119
123
expected = [
120
- Node (text = None , id = "twin_a" , metadata = {"twin" : "true" , "index" : "0.0" }),
121
- Node (text = None , id = "twin_b" , metadata = {"twin" : "true" , "index" : "1.0" }),
124
+ Node (text = "" , id = "twin_a" , metadata = {"twin" : "true" , "index" : "0.0" }),
125
+ Node (text = "" , id = "twin_b" , metadata = {"twin" : "true" , "index" : "1.0" }),
122
126
]
123
127
assert md_twins_gotten == expected
124
128
assert list (gs .metadata_search (metadata = {"fake" : True })) == []
125
129
126
130
127
131
def test_graph_store_metadata_routing (
128
- graph_store_factory : Callable [[str ], GraphStore ],
132
+ graph_store_factory : Callable [[MetadataIndexingType ], GraphStore ],
129
133
) -> None :
130
134
test_md = {"mds" : "string" , "mdn" : 255 , "mdb" : True }
131
135
test_md_string = {"mds" : "string" , "mdn" : "255.0" , "mdb" : "true" }
132
136
133
- gs_all = graph_store_factory (metadata_indexing = "all" )
137
+ gs_all = graph_store_factory ("all" )
134
138
gs_all .add_nodes ([Node (id = "row1" , text = "bb1" , metadata = test_md )])
135
139
gotten_all = list (gs_all .metadata_search (metadata = {"mds" : "string" }))[0 ]
136
140
assert gotten_all .metadata == test_md_string
137
- gs_none = graph_store_factory (metadata_indexing = "none" )
141
+ gs_none = graph_store_factory ("none" )
138
142
gs_none .add_nodes ([Node (id = "row1" , text = "bb1" , metadata = test_md )])
139
143
with pytest .raises (ValueError ):
140
144
# querying on non-indexed metadata fields:
@@ -158,15 +162,13 @@ def test_graph_store_metadata_routing(
158
162
"mdab" : "true" ,
159
163
"mddb" : "true" ,
160
164
}
161
- gs_allow = graph_store_factory (
162
- metadata_indexing = ("allow" , {"mdas" , "mdan" , "mdab" })
163
- )
165
+ gs_allow = graph_store_factory (("allow" , {"mdas" , "mdan" , "mdab" }))
164
166
gs_allow .add_nodes ([Node (id = "row1" , text = "bb1" , metadata = test_md_allowdeny )])
165
167
with pytest .raises (ValueError ):
166
168
list (gs_allow .metadata_search (metadata = {"mdds" : "MDDS" }))
167
169
gotten_allow = list (gs_allow .metadata_search (metadata = {"mdas" : "MDAS" }))[0 ]
168
170
assert gotten_allow .metadata == test_md_allowdeny_string
169
- gs_deny = graph_store_factory (metadata_indexing = ("deny" , {"mdds" , "mddn" , "mddb" }))
171
+ gs_deny = graph_store_factory (("deny" , {"mdds" , "mddn" , "mddb" }))
170
172
gs_deny .add_nodes ([Node (id = "row1" , text = "bb1" , metadata = test_md_allowdeny )])
171
173
with pytest .raises (ValueError ):
172
174
list (gs_deny .metadata_search (metadata = {"mdds" : "MDDS" }))
0 commit comments