1
1
import math
2
2
import secrets
3
- from typing import Callable , Iterable , Iterator , List
3
+ from typing import Iterable , Iterator , List
4
4
5
5
import numpy as np
6
6
import pytest
@@ -89,26 +89,25 @@ def cassandra() -> Iterator[LocalCassandraTestStore]:
89
89
90
90
91
91
@pytest .fixture ()
92
- def graph_store_factory (
92
+ def graph_store (
93
93
cassandra : LocalCassandraTestStore ,
94
- ) -> Iterator [Callable [[], GraphStore ] ]:
94
+ ) -> Iterator [GraphStore ]:
95
95
session = cassandra .create_cassandra_session ()
96
96
session .set_keyspace (KEYSPACE )
97
97
98
98
embedding = SimpleEmbeddingModel ()
99
99
100
- def _make_graph_store () -> GraphStore :
101
- name = secrets .token_hex (8 )
100
+ name = secrets .token_hex (8 )
102
101
103
- node_table = f"nodes_{ name } "
104
- return GraphStore (
105
- embedding ,
106
- session = session ,
107
- keyspace = KEYSPACE ,
108
- node_table = node_table ,
109
- )
102
+ node_table = f"nodes_{ name } "
103
+ store = GraphStore (
104
+ embedding ,
105
+ session = session ,
106
+ keyspace = KEYSPACE ,
107
+ node_table = node_table ,
108
+ )
110
109
111
- yield _make_graph_store
110
+ yield store
112
111
113
112
session .shutdown ()
114
113
@@ -117,15 +116,7 @@ def _result_ids(nodes: Iterable[Node]) -> List[str]:
117
116
return [n .id for n in nodes if n .id is not None ]
118
117
119
118
120
- def test_graph_store_creation (graph_store_factory : Callable [[], GraphStore ]) -> None :
121
- """Test that a graph store can be created.
122
-
123
- This verifies the schema can be applied and the queries prepared.
124
- """
125
- graph_store_factory ()
126
-
127
-
128
- def test_mmr_traversal (graph_store_factory : Callable [[], GraphStore ]) -> None :
119
+ def test_mmr_traversal (graph_store : GraphStore ) -> None :
129
120
"""
130
121
Test end to end construction and MMR search.
131
122
The embedding function used here ensures `texts` become
@@ -145,8 +136,6 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
145
136
Both v2 and v3 are reachable via edges from v0, so once it is
146
137
selected, those are both considered.
147
138
"""
148
- gs = graph_store_factory ()
149
-
150
139
v0 = Node (
151
140
id = "v0" ,
152
141
text = "-0.124" ,
@@ -166,32 +155,30 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
166
155
text = "+1.0" ,
167
156
links = {Link (direction = "in" , kind = "explicit" , tag = "link" )},
168
157
)
169
- gs .add_nodes ([v0 , v1 , v2 , v3 ])
158
+ graph_store .add_nodes ([v0 , v1 , v2 , v3 ])
170
159
171
- results = gs .mmr_traversal_search ("0.0" , k = 2 , fetch_k = 2 )
160
+ results = graph_store .mmr_traversal_search ("0.0" , k = 2 , fetch_k = 2 )
172
161
assert _result_ids (results ) == ["v0" , "v2" ]
173
162
174
163
# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
175
164
# So it ends up picking "v1" even though it's similar to "v0".
176
- results = gs .mmr_traversal_search ("0.0" , k = 2 , fetch_k = 2 , depth = 0 )
165
+ results = graph_store .mmr_traversal_search ("0.0" , k = 2 , fetch_k = 2 , depth = 0 )
177
166
assert _result_ids (results ) == ["v0" , "v1" ]
178
167
179
168
# With max depth 0 but higher `fetch_k`, we encounter v2
180
- results = gs .mmr_traversal_search ("0.0" , k = 2 , fetch_k = 3 , depth = 0 )
169
+ results = graph_store .mmr_traversal_search ("0.0" , k = 2 , fetch_k = 3 , depth = 0 )
181
170
assert _result_ids (results ) == ["v0" , "v2" ]
182
171
183
172
# v0 score is .46, v2 score is 0.16 so it won't be chosen.
184
- results = gs .mmr_traversal_search ("0.0" , k = 2 , score_threshold = 0.2 )
173
+ results = graph_store .mmr_traversal_search ("0.0" , k = 2 , score_threshold = 0.2 )
185
174
assert _result_ids (results ) == ["v0" ]
186
175
187
176
# with k=4 we should get all of the documents.
188
- results = gs .mmr_traversal_search ("0.0" , k = 4 )
177
+ results = graph_store .mmr_traversal_search ("0.0" , k = 4 )
189
178
assert _result_ids (results ) == ["v0" , "v2" , "v1" , "v3" ]
190
179
191
180
192
- def test_write_retrieve_keywords (graph_store_factory : Callable [[], GraphStore ]) -> None :
193
- gs = graph_store_factory ()
194
-
181
+ def test_write_retrieve_keywords (graph_store : GraphStore ) -> None :
195
182
greetings = Node (
196
183
id = "greetings" ,
197
184
text = "Typical Greetings" ,
@@ -218,36 +205,34 @@ def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore])
218
205
},
219
206
)
220
207
221
- gs .add_nodes ([greetings , doc1 , doc2 ])
208
+ graph_store .add_nodes ([greetings , doc1 , doc2 ])
222
209
223
210
# Doc2 is more similar, but World and Earth are similar enough that doc1 also shows
224
211
# up.
225
- results = gs .similarity_search (text_to_embedding ("Earth" ), k = 2 )
212
+ results = graph_store .similarity_search (text_to_embedding ("Earth" ), k = 2 )
226
213
assert _result_ids (results ) == ["doc2" , "doc1" ]
227
214
228
- results = gs .similarity_search (text_to_embedding ("Earth" ), k = 1 )
215
+ results = graph_store .similarity_search (text_to_embedding ("Earth" ), k = 1 )
229
216
assert _result_ids (results ) == ["doc2" ]
230
217
231
- results = gs .traversal_search ("Earth" , k = 2 , depth = 0 )
218
+ results = graph_store .traversal_search ("Earth" , k = 2 , depth = 0 )
232
219
assert _result_ids (results ) == ["doc2" , "doc1" ]
233
220
234
- results = gs .traversal_search ("Earth" , k = 2 , depth = 1 )
221
+ results = graph_store .traversal_search ("Earth" , k = 2 , depth = 1 )
235
222
assert _result_ids (results ) == ["doc2" , "doc1" , "greetings" ]
236
223
237
224
# K=1 only pulls in doc2 (Hello Earth)
238
- results = gs .traversal_search ("Earth" , k = 1 , depth = 0 )
225
+ results = graph_store .traversal_search ("Earth" , k = 1 , depth = 0 )
239
226
assert _result_ids (results ) == ["doc2" ]
240
227
241
228
# K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via keyword
242
229
# edge.
243
- results = gs .traversal_search ("Earth" , k = 1 , depth = 1 )
230
+ results = graph_store .traversal_search ("Earth" , k = 1 , depth = 1 )
244
231
assert set (_result_ids (results )) == {"doc2" , "doc1" , "greetings" }
245
232
246
233
247
- def test_metadata (graph_store_factory : Callable [[], GraphStore ]) -> None :
248
- gs = graph_store_factory ()
249
-
250
- gs .add_nodes (
234
+ def test_metadata (graph_store : GraphStore ) -> None :
235
+ graph_store .add_nodes (
251
236
[
252
237
Node (
253
238
id = "a" ,
@@ -260,7 +245,7 @@ def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
260
245
)
261
246
]
262
247
)
263
- results = list (gs .similarity_search (text_to_embedding ("A" )))
248
+ results = list (graph_store .similarity_search (text_to_embedding ("A" )))
264
249
assert len (results ) == 1
265
250
assert results [0 ].id == "a"
266
251
assert results [0 ].metadata ["other" ] == "some other field"
0 commit comments