Skip to content

Commit 592ecb5

Browse files
authored
Move graph integration tests to libs/knowledge-store (#596)
* progress moving graph integration tests * simplified test * fix tests * fix lint
1 parent 8e2ffdc commit 592ecb5

File tree

2 files changed

+216
-295
lines changed

2 files changed

+216
-295
lines changed
Lines changed: 216 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,91 @@
1+
import math
12
import secrets
2-
from typing import Callable, Iterator, List
3+
from typing import Callable, Iterable, Iterator, List
34

5+
import numpy as np
46
import pytest
57
from dotenv import load_dotenv
68
from ragstack_knowledge_store import EmbeddingModel
7-
from ragstack_knowledge_store.graph_store import GraphStore
9+
from ragstack_knowledge_store.graph_store import GraphStore, Node
10+
from ragstack_knowledge_store.links import Link
811
from ragstack_tests_utils import LocalCassandraTestStore
912

1013
load_dotenv()
1114

1215
KEYSPACE = "default_keyspace"
1316

17+
vector_size = 52
1418

15-
@pytest.fixture(scope="session")
16-
def cassandra() -> Iterator[LocalCassandraTestStore]:
17-
store = LocalCassandraTestStore()
18-
yield store
1919

20-
if store.docker_container:
21-
store.docker_container.stop()
20+
def text_to_embedding(text: str) -> List[float]:
    """Embed ``text`` as a fixed-length vector of scaled character codes.

    Components 0 and 1 are always left at zero. Starting at component 2, each
    slot holds ``ord(char) / 255`` for the successive characters of ``text``.
    Only the first ``vector_size - 2`` characters are used; shorter text is
    zero-padded up to ``vector_size`` components.
    """
    values = np.zeros(vector_size)
    # Characters beyond the available slots are ignored.
    for slot, symbol in enumerate(text[: vector_size - 2], start=2):
        values[slot] = ord(symbol) / 255
    result: List[float] = values.tolist()
    return result
2229

2330

24-
DUMMY_VECTOR = [0.1, 0.2]
31+
def angle_to_embedding(angle: float) -> List[float]:
    """Place ``angle`` (expressed in units of pi) on the unit circle.

    Components 0 and 1 of the returned ``vector_size``-long vector hold the
    cosine and sine of ``angle * pi``; every other component is zero.
    """
    point = np.zeros(vector_size)
    radians = angle * math.pi
    point[0] = math.cos(radians)
    point[1] = math.sin(radians)
    as_list: List[float] = point.tolist()
    return as_list
2538

2639

27-
class DummyEmbeddingModel(EmbeddingModel):
40+
class SimpleEmbeddingModel(EmbeddingModel):
    """Deterministic embedding model for tests.

    Text that parses as a finite number is read as an angle in units of pi and
    embedded on the unit circle (components 0 and 1 of the vector). Any other
    text is embedded from its character codes. Both kinds of vector have
    ``vector_size`` components.
    """

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in ``texts``, preserving order."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Convert ``text`` to a vector (list of floats).

        If ``text`` is a finite number, it is used as the angle (in units of
        pi) of a unit vector. Any other text, including "nan", "inf" and
        values too large to scale by pi, is embedded from its characters.
        """
        try:
            angle = float(text)
        except ValueError:
            # Not numeric at all: plain test string.
            return text_to_embedding(text)

        # float() also accepts "nan" and "inf". A NaN angle would yield a NaN
        # vector, and an infinite (or overflowing) angle makes math.cos raise,
        # so such inputs are treated as ordinary text instead.
        if not math.isfinite(angle * math.pi):
            return text_to_embedding(text)
        return angle_to_embedding(angle=angle)

    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``embed_texts``; delegates to the sync method."""
        return self.embed_texts(texts=texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant of ``embed_query``; delegates to the sync method."""
        return self.embed_query(text=text)
80+
81+
82+
@pytest.fixture(scope="session")
def cassandra() -> Iterator[LocalCassandraTestStore]:
    """Provide one ``LocalCassandraTestStore`` shared by the whole session.

    Once the session is over, the store's docker container is stopped if the
    store has one.
    """
    test_store = LocalCassandraTestStore()
    yield test_store

    # Teardown: only stop a container when the store actually holds one.
    if not test_store.docker_container:
        return
    test_store.docker_container.stop()
3989

4090

4191
@pytest.fixture()
@@ -45,7 +95,7 @@ def graph_store_factory(
4595
session = cassandra.create_cassandra_session()
4696
session.set_keyspace(KEYSPACE)
4797

48-
embedding = DummyEmbeddingModel()
98+
embedding = SimpleEmbeddingModel()
4999

50100
def _make_graph_store() -> GraphStore:
51101
name = secrets.token_hex(8)
@@ -63,9 +113,158 @@ def _make_graph_store() -> GraphStore:
63113
session.shutdown()
64114

65115

116+
def _result_ids(nodes: Iterable[Node]) -> List[str]:
    """Return the ids of ``nodes`` in iteration order, skipping ``None`` ids."""
    ids: List[str] = []
    for node in nodes:
        if node.id is None:
            continue
        ids.append(node.id)
    return ids
118+
119+
66120
def test_graph_store_creation(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Check that building a graph store succeeds.

    Creating the store is the whole test: it exercises applying the schema
    and preparing the queries.
    """
    # The returned store is not needed; construction without an error is enough.
    _ = graph_store_factory()
126+
127+
128+
def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Test end-to-end construction and MMR traversal search.

    Each node's text is a number, so the embedding model places it on the
    unit circle at that angle (in units of pi):

    * v0 at -0.124 (closest to the query)
    * v1 at +0.127 (nearly as close as v0, on the other side of the query)
    * v2 at +0.25
    * v3 at +1.0 (opposite the query)

    The query "0.0" sits at angle 0. v0 carries an outgoing "link" tag that
    v2 and v3 carry as incoming, so both are reachable by an edge from v0.

    With ``fetch_k == 2`` and ``k == 2`` the expected result is v0 then v2:
    v1 is "too close" to v0 (and v0 is closer to the query than v1), while v2
    and v3 become candidates once v0 has been selected.
    """
    store = graph_store_factory()
    query = "0.0"

    store.add_nodes(
        [
            Node(
                id="v0",
                text="-0.124",
                links={Link(direction="out", kind="explicit", tag="link")},
            ),
            Node(id="v1", text="+0.127"),
            Node(
                id="v2",
                text="+0.25",
                links={Link(direction="in", kind="explicit", tag="link")},
            ),
            Node(
                id="v3",
                text="+1.0",
                links={Link(direction="in", kind="explicit", tag="link")},
            ),
        ]
    )

    found = store.mmr_traversal_search(query, k=2, fetch_k=2)
    assert _result_ids(found) == ["v0", "v2"]

    # At depth 0 no edges are followed, so v2 and v3 are never reached and
    # v1 is picked despite its similarity to v0.
    found = store.mmr_traversal_search(query, k=2, fetch_k=2, depth=0)
    assert _result_ids(found) == ["v0", "v1"]

    # Still depth 0, but a larger `fetch_k` brings v2 into the candidates.
    found = store.mmr_traversal_search(query, k=2, fetch_k=3, depth=0)
    assert _result_ids(found) == ["v0", "v2"]

    # v0 scores .46 and v2 scores 0.16, so the threshold excludes v2.
    found = store.mmr_traversal_search(query, k=2, score_threshold=0.2)
    assert _result_ids(found) == ["v0"]

    # k=4 should return every document.
    found = store.mmr_traversal_search(query, k=4)
    assert _result_ids(found) == ["v0", "v2", "v1", "v3"]
190+
191+
192+
def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Check similarity and traversal search over parent- and keyword-linked nodes.

    ``doc1`` and ``doc2`` each carry an outgoing "parent" tag that
    ``greetings`` carries as incoming, and they share the bidirectional
    "greeting" keyword tag.
    """
    store = graph_store_factory()

    store.add_nodes(
        [
            Node(
                id="greetings",
                text="Typical Greetings",
                links={
                    Link(direction="in", kind="parent", tag="parent"),
                },
            ),
            Node(
                id="doc1",
                text="Hello World",
                links={
                    Link(direction="out", kind="parent", tag="parent"),
                    Link(direction="bidir", kind="kw", tag="greeting"),
                    Link(direction="bidir", kind="kw", tag="world"),
                },
            ),
            Node(
                id="doc2",
                text="Hello Earth",
                links={
                    Link(direction="out", kind="parent", tag="parent"),
                    Link(direction="bidir", kind="kw", tag="greeting"),
                    Link(direction="bidir", kind="kw", tag="earth"),
                },
            ),
        ]
    )

    # doc2 is the closer match, but "World" and "Earth" are similar enough
    # for doc1 to be returned as well.
    found = store.similarity_search(text_to_embedding("Earth"), k=2)
    assert _result_ids(found) == ["doc2", "doc1"]

    found = store.similarity_search(text_to_embedding("Earth"), k=1)
    assert _result_ids(found) == ["doc2"]

    found = store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(found) == ["doc2", "doc1"]

    found = store.traversal_search("Earth", k=2, depth=1)
    assert _result_ids(found) == ["doc2", "doc1", "greetings"]

    # k=1 retrieves only doc2 (Hello Earth).
    found = store.traversal_search("Earth", k=1, depth=0)
    assert _result_ids(found) == ["doc2"]

    # k=1 starts from doc2 (Hello Earth) alone; depth=1 then follows the
    # parent edge and the keyword edge. Order is not asserted here.
    found = store.traversal_search("Earth", k=1, depth=1)
    assert set(_result_ids(found)) == {"doc2", "doc1", "greetings"}
245+
246+
247+
def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
    """Check that a node's id, metadata and links come back from a search."""
    store = graph_store_factory()

    node_a = Node(
        id="a",
        text="A",
        links={
            Link(direction="in", kind="hyperlink", tag="http://a"),
            Link(direction="bidir", kind="other", tag="foo"),
        },
        metadata={"other": "some other field"},
    )
    store.add_nodes([node_a])

    found = list(store.similarity_search(text_to_embedding("A")))
    assert len(found) == 1

    only_hit = found[0]
    assert only_hit.id == "a"
    assert only_hit.metadata["other"] == "some other field"
    assert only_hit.links == {
        Link(direction="in", kind="hyperlink", tag="http://a"),
        Link(direction="bidir", kind="other", tag="foo"),
    }

0 commit comments

Comments
 (0)