1
+ import math
1
2
import secrets
2
- from typing import Callable , Iterator , List
3
+ from typing import Callable , Iterable , Iterator , List
3
4
5
+ import numpy as np
4
6
import pytest
5
7
from dotenv import load_dotenv
6
8
from ragstack_knowledge_store import EmbeddingModel
7
- from ragstack_knowledge_store .graph_store import GraphStore
9
+ from ragstack_knowledge_store .graph_store import GraphStore , Node
10
+ from ragstack_knowledge_store .links import Link
8
11
from ragstack_tests_utils import LocalCassandraTestStore
9
12
10
13
load_dotenv ()
11
14
12
15
# Keyspace the Cassandra test session is switched to before graph stores are built.
KEYSPACE = "default_keyspace"
13
16
17
# Length of every embedding produced by the helper functions in this module:
# slots 0-1 carry the circle coordinates for numeric text, and the remaining
# slots carry scaled character codes for any other text.
+ vector_size = 52
14
18
15
- @pytest .fixture (scope = "session" )
16
- def cassandra () -> Iterator [LocalCassandraTestStore ]:
17
- store = LocalCassandraTestStore ()
18
- yield store
19
19
20
- if store .docker_container :
21
- store .docker_container .stop ()
20
def text_to_embedding(text: str) -> List[float]:
    """Embed text by writing scaled character codes into a fixed-size vector.

    The vector has ``vector_size`` components. The first two are left at zero;
    character ``n`` of ``text`` is stored in component ``n + 2`` as
    ``ord(char) / 255``. Characters that do not fit in the vector are ignored.
    """
    slots = np.zeros(vector_size)
    capacity = vector_size - 2
    # Slicing drops the overflow up front instead of breaking out of the loop.
    for position, char in enumerate(text[:capacity], start=2):
        slots[position] = ord(char) / 255
    result: List[float] = slots.tolist()
    return result
22
29
23
30
24
- DUMMY_VECTOR = [0.1 , 0.2 ]
31
def angle_to_embedding(angle: float) -> List[float]:
    """Place an angle, given in units of pi, on the unit circle.

    The cosine and sine of ``angle * pi`` fill the first two components of a
    ``vector_size``-long vector; every other component is zero.
    """
    radians = angle * math.pi
    coords = np.zeros(vector_size)
    coords[:2] = (math.cos(radians), math.sin(radians))
    result: List[float] = coords.tolist()
    return result
25
38
26
39
27
- class DummyEmbeddingModel (EmbeddingModel ):
40
class SimpleEmbeddingModel(EmbeddingModel):
    """Deterministic embedding model for tests.

    Text that parses as a finite number is treated as an angle in units of pi
    and embedded on the unit circle (first two vector components). Any other
    text is embedded character by character. Both kinds of vector have
    ``vector_size`` components.
    """

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed every text, in order, using ``embed_query``."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Convert ``text`` to an embedding vector (list of floats).

        If ``text`` is a finite number it is used as the angle, in units of
        pi, of a unit vector. Everything else, including ``"nan"`` and
        ``"inf"``, is embedded as plain text.
        """
        try:
            angle = float(text)
            # float() also accepts "nan"; cos/sin of NaN is NaN, which would
            # silently produce an unusable embedding, so only finite angles
            # are placed on the circle.
            if math.isfinite(angle):
                return angle_to_embedding(angle=angle)
        except ValueError:
            # Not numeric (or the angle helper raised ValueError): fall
            # through to the plain-text embedding, as before.
            pass
        return text_to_embedding(text)

    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``embed_texts``; delegates to it directly."""
        return self.embed_texts(texts=texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Async counterpart of ``embed_query``; delegates to it directly."""
        return self.embed_query(text=text)
80
+
81
+
82
@pytest.fixture(scope="session")
def cassandra() -> Iterator[LocalCassandraTestStore]:
    """Provide one ``LocalCassandraTestStore`` for the whole test session.

    After the session finishes, the store's docker container is stopped if
    the store has one.
    """
    local_store = LocalCassandraTestStore()
    yield local_store

    container = local_store.docker_container
    if container:
        container.stop()
39
89
40
90
41
91
@pytest .fixture ()
@@ -45,7 +95,7 @@ def graph_store_factory(
45
95
session = cassandra .create_cassandra_session ()
46
96
session .set_keyspace (KEYSPACE )
47
97
48
- embedding = DummyEmbeddingModel ()
98
+ embedding = SimpleEmbeddingModel ()
49
99
50
100
def _make_graph_store () -> GraphStore :
51
101
name = secrets .token_hex (8 )
@@ -63,9 +113,158 @@ def _make_graph_store() -> GraphStore:
63
113
session .shutdown ()
64
114
65
115
116
def _result_ids(nodes: Iterable[Node]) -> List[str]:
    """Return the ids of ``nodes`` in iteration order, skipping ``None`` ids."""
    ids: List[str] = []
    for node in nodes:
        if node.id is None:
            continue
        ids.append(node.id)
    return ids
118
+
119
+
66
120
def test_graph_store_creation(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Check that building a graph store succeeds.

    Creating the store is the whole test: it verifies the schema can be
    applied and the queries prepared.
    """
    _ = graph_store_factory()
126
+
127
+
128
def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Test end to end construction and MMR search.

    The node texts are numbers, which the test embedding model places on a
    circle as angles in units of pi:

        v0: -0.124   just to one side of the query
        v1: +0.127   just to the other side, slightly farther than v0
        v2: +0.25
        v3: +1.0     opposite the query

    The query "0.0" sits at angle 0. v0 carries an outgoing "link" edge and
    v2 and v3 each carry the matching incoming one, so both are reachable
    from v0 once it has been selected.

    With fetch_k == 2 and k == 2, v0 and v2 are expected: v1 is "too close"
    to v0 (and v0 is closer to the query than v1).
    """
    store = graph_store_factory()

    store.add_nodes(
        [
            Node(
                id="v0",
                text="-0.124",
                links={Link(direction="out", kind="explicit", tag="link")},
            ),
            Node(
                id="v1",
                text="+0.127",
            ),
            Node(
                id="v2",
                text="+0.25",
                links={Link(direction="in", kind="explicit", tag="link")},
            ),
            Node(
                id="v3",
                text="+1.0",
                links={Link(direction="in", kind="explicit", tag="link")},
            ),
        ]
    )

    found = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
    assert _result_ids(found) == ["v0", "v2"]

    # Depth 0 means no edges are followed, so v2 and v3 are never reached and
    # v1 is picked despite being similar to v0.
    found = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
    assert _result_ids(found) == ["v0", "v1"]

    # Still depth 0, but the larger `fetch_k` brings v2 into the candidates.
    found = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
    assert _result_ids(found) == ["v0", "v2"]

    # v0 scores .46 and v2 scores 0.16, so the threshold leaves only v0.
    found = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
    assert _result_ids(found) == ["v0"]

    # k=4 is enough to return every document.
    found = store.mmr_traversal_search("0.0", k=4)
    assert _result_ids(found) == ["v0", "v2", "v1", "v3"]
190
+
191
+
192
def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Store a parent node and two keyword-linked documents, then search them.

    Covers plain similarity search by vector and traversal search by text at
    depth 0 and depth 1.
    """
    store = graph_store_factory()

    store.add_nodes(
        [
            Node(
                id="greetings",
                text="Typical Greetings",
                links={
                    Link(direction="in", kind="parent", tag="parent"),
                },
            ),
            Node(
                id="doc1",
                text="Hello World",
                links={
                    Link(direction="out", kind="parent", tag="parent"),
                    Link(direction="bidir", kind="kw", tag="greeting"),
                    Link(direction="bidir", kind="kw", tag="world"),
                },
            ),
            Node(
                id="doc2",
                text="Hello Earth",
                links={
                    Link(direction="out", kind="parent", tag="parent"),
                    Link(direction="bidir", kind="kw", tag="greeting"),
                    Link(direction="bidir", kind="kw", tag="earth"),
                },
            ),
        ]
    )

    # doc2 is the closer match, but World and Earth are similar enough that
    # doc1 is returned as well.
    found = store.similarity_search(text_to_embedding("Earth"), k=2)
    assert _result_ids(found) == ["doc2", "doc1"]

    found = store.similarity_search(text_to_embedding("Earth"), k=1)
    assert _result_ids(found) == ["doc2"]

    found = store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(found) == ["doc2", "doc1"]

    found = store.traversal_search("Earth", k=2, depth=1)
    assert _result_ids(found) == ["doc2", "doc1", "greetings"]

    # k=1 with no traversal yields only doc2 (Hello Earth).
    found = store.traversal_search("Earth", k=1, depth=0)
    assert _result_ids(found) == ["doc2"]

    # k=1 starts from doc2 (Hello Earth) alone; depth=1 then follows both the
    # parent edge and the keyword edge.
    found = store.traversal_search("Earth", k=1, depth=1)
    assert set(_result_ids(found)) == {"doc2", "doc1", "greetings"}
245
+
246
+
247
def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Check that id, metadata and links survive a store/search round trip."""
    store = graph_store_factory()

    node = Node(
        id="a",
        text="A",
        links={
            Link(direction="in", kind="hyperlink", tag="http://a"),
            Link(direction="bidir", kind="other", tag="foo"),
        },
        metadata={"other": "some other field"},
    )
    store.add_nodes([node])

    matches = list(store.similarity_search(text_to_embedding("A")))
    assert len(matches) == 1

    match = matches[0]
    assert match.id == "a"
    assert match.metadata["other"] == "some other field"
    assert match.links == {
        Link(direction="in", kind="hyperlink", tag="http://a"),
        Link(direction="bidir", kind="other", tag="foo"),
    }
0 commit comments