Skip to content

Commit d239afa

Browse files
committed
Restore integration_tests/test_graph_store.py
1 parent 479846f commit d239afa

File tree

1 file changed

+281
-0
lines changed

1 file changed

+281
-0
lines changed
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
import math
2+
import secrets
3+
from typing import Iterable, List, Optional
4+
5+
import pytest
6+
from cassandra.cluster import Session
7+
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
8+
from langchain_core.documents import Document
9+
from langchain_core.embeddings import Embeddings
10+
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
11+
from ragstack_tests_utils.test_store import KEYSPACE
12+
13+
from .conftest import get_astradb_test_store, get_local_cassandra_test_store
14+
15+
16+
class GraphStoreFactory:
    """Lazily build a single ``CassandraGraphVectorStore`` on its own tables.

    Every factory draws a random hex suffix and uses it in both its node table
    and its targets table name (presumably so that factories sharing a keyspace
    do not collide).
    """

    def __init__(self, session: Session, keyspace: str, embedding: Embeddings) -> None:
        """Remember the connection details and pick the table names.

        Args:
            session: Cassandra session used to create and drop the tables.
            keyspace: Keyspace that holds the tables.
            embedding: Embedding used when `store` is not given another one.
        """
        suffix = secrets.token_hex(8)
        self.session = session
        self.keyspace = keyspace
        self.embedding = embedding
        self.uid = suffix
        self.node_table = f"nodes_{suffix}"
        self.targets_table = f"targets_{suffix}"
        # Filled in by the first call to `store`.
        self._store: Optional[CassandraGraphVectorStore] = None

    def store(
        self,
        initial_documents: Iterable[Document] = (),
        ids: Optional[Iterable[str]] = None,
        embedding: Optional[Embeddings] = None,
    ) -> CassandraGraphVectorStore:
        """Return this factory's store, creating it on the first call.

        Args:
            initial_documents: Documents handed to ``from_documents`` when the
                store is created.
            ids: Forwarded as ``ids`` when the store is created.
            embedding: Used in place of the factory's embedding when truthy.
                Only consulted by the call that creates the store.

        Returns:
            The store created by the first call; later calls return the same
            object.

        Raises:
            ValueError: If `initial_documents` is truthy but the store already
                exists.
        """
        if self._store is not None:
            # Documents can only be supplied by the call that creates the store.
            if initial_documents:
                raise ValueError("Store already initialized")
            return self._store

        self._store = CassandraGraphVectorStore.from_documents(
            initial_documents,
            embedding=embedding or self.embedding,
            session=self.session,
            keyspace=self.keyspace,
            node_table=self.node_table,
            targets_table=self.targets_table,
            ids=ids,
        )
        return self._store

    def drop(self):
        """Drop the node table, then the targets table, if they exist."""
        for table in (self.node_table, self.targets_table):
            self.session.execute(f"DROP TABLE IF EXISTS {self.keyspace}.{table};")
52+
53+
54+
@pytest.fixture(scope="session")
def openai_embedding() -> Embeddings:
    """Session-scoped fixture returning a default-constructed ``OpenAIEmbeddings``."""
    # Imported inside the fixture, so the package is only loaded when the
    # fixture is actually requested.
    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings()
    return embeddings
59+
60+
61+
@pytest.fixture()
def cassandra(openai_embedding: Embeddings):
    """Yield a ``GraphStoreFactory`` bound to the local Cassandra test store.

    After the yield, the factory's tables are dropped.
    """
    test_store = get_local_cassandra_test_store()
    factory = GraphStoreFactory(
        session=test_store.create_cassandra_session(),
        keyspace=KEYSPACE,
        embedding=openai_embedding,
    )
    yield factory
    # Teardown: remove the tables this factory may have created.
    factory.drop()
70+
71+
72+
@pytest.fixture()
def astra_db(openai_embedding: Embeddings):
    """Yield a ``GraphStoreFactory`` bound to the Astra DB test store.

    After the yield, the factory's tables are dropped.
    """
    test_store = get_astradb_test_store()
    factory = GraphStoreFactory(
        session=test_store.create_cassandra_session(),
        keyspace=KEYSPACE,
        embedding=openai_embedding,
    )
    yield factory
    # Teardown: remove the tables this factory may have created.
    factory.drop()
81+
82+
83+
class AngularTwoDimensionalEmbeddings(Embeddings):
    """
    From angles (as strings in units of pi) to unit embedding vectors on a circle.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Make a list of texts into a list of embedding vectors.
        """
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """
        Convert input text to a 'vector' (list of floats).
        If the text is a number, use it as the angle for the
        unit vector in units of pi.
        Any other input text becomes the singular result [0, 0] !
        This includes text that parses as a float but gives no usable
        angle ("nan", "inf", or a value that overflows once scaled by pi).
        """
        fallback = [0.0, 0.0]
        try:
            angle = float(text)
        except ValueError:
            # Assume: just test string, no attention is paid to values.
            return fallback

        radians = angle * math.pi
        # NaN would flow through cos/sin as [nan, nan], and an infinite angle
        # makes math.cos raise ValueError; send both to the fallback explicitly.
        if not math.isfinite(radians):
            return fallback
        return [math.cos(radians), math.sin(radians)]
107+
108+
109+
def _result_ids(docs: Iterable[Document]) -> List[str]:
    """Return the ``id`` of each document, in iteration order."""
    return [doc.id for doc in docs]
111+
112+
113+
@pytest.mark.parametrize("gs_factory", ["cassandra", "astra_db"])
def test_mmr_traversal(request, gs_factory: str):
    r"""
    Test end to end construction and MMR search.
    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

           ______ v2
          /      \
         /        |  v1
    v3  |     .    | query
         |        /  v0
          |______/                 (N.B. very crude drawing)

    With fetch_k==2 and k==2, when query is at (1, 0),
    one expects that v2 and v0 are returned (in some order)
    because v1 is "too close" to v0 (and v0 is closer than v1).

    Both v2 and v3 are reachable via edges from v0, so once it is
    selected, those are both considered.
    """
    factory = request.getfixturevalue(gs_factory)
    store = factory.store(embedding=AngularTwoDimensionalEmbeddings())

    # v0 carries the outgoing "explicit"/"link" link; v2 and v3 carry the
    # matching incoming link; v1 has no links at all.
    documents = [
        Document(
            id="v0",
            page_content="-0.124",
            metadata={METADATA_LINKS_KEY: {Link.outgoing(kind="explicit", tag="link")}},
        ),
        Document(id="v1", page_content="+0.127"),
        Document(
            id="v2",
            page_content="+0.25",
            metadata={METADATA_LINKS_KEY: {Link.incoming(kind="explicit", tag="link")}},
        ),
        Document(
            id="v3",
            page_content="+1.0",
            metadata={METADATA_LINKS_KEY: {Link.incoming(kind="explicit", tag="link")}},
        ),
    ]
    store.add_documents(documents)

    found = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
    assert _result_ids(found) == ["v0", "v2"]

    # Depth 0 means no edges are followed, so v2 and v3 are never reached and
    # "v1" is picked despite being similar to "v0".
    found = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
    assert _result_ids(found) == ["v0", "v1"]

    # Still depth 0, but a larger `fetch_k` brings v2 into consideration.
    found = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
    assert _result_ids(found) == ["v0", "v2"]

    # v0 score is .46, v2 score is 0.16 so it won't be chosen.
    found = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
    assert _result_ids(found) == ["v0"]

    # Asking for k=4 should return every document.
    found = store.mmr_traversal_search("0.0", k=4)
    assert _result_ids(found) == ["v0", "v2", "v1", "v3"]
191+
192+
193+
@pytest.mark.parametrize("gs_factory", ["cassandra", "astra_db"])
def test_write_retrieve_keywords(request, gs_factory: str):
    """Check similarity and traversal search over parent- and keyword-linked docs."""
    factory = request.getfixturevalue(gs_factory)

    # "greetings" receives the "parent" link that both child documents emit.
    parent_doc = Document(
        id="greetings",
        page_content="Typical Greetings",
        metadata={METADATA_LINKS_KEY: {Link.incoming(kind="parent", tag="parent")}},
    )
    # Both children share the "greeting" keyword; each adds one of its own.
    world_doc = Document(
        id="doc1",
        page_content="Hello World",
        metadata={
            METADATA_LINKS_KEY: {
                Link.outgoing(kind="parent", tag="parent"),
                Link.bidir(kind="kw", tag="greeting"),
                Link.bidir(kind="kw", tag="world"),
            },
        },
    )
    earth_doc = Document(
        id="doc2",
        page_content="Hello Earth",
        metadata={
            METADATA_LINKS_KEY: {
                Link.outgoing(kind="parent", tag="parent"),
                Link.bidir(kind="kw", tag="greeting"),
                Link.bidir(kind="kw", tag="earth"),
            },
        },
    )

    store = factory.store([parent_doc, world_doc, earth_doc])

    # doc2 is the closer match, but World and Earth are similar enough that
    # doc1 is returned as well.
    found = store.similarity_search("Earth", k=2)
    assert _result_ids(found) == ["doc2", "doc1"]

    found = store.similarity_search("Earth", k=1)
    assert _result_ids(found) == ["doc2"]

    found = store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(found) == ["doc2", "doc1"]

    found = store.traversal_search("Earth", k=2, depth=1)
    assert _result_ids(found) == ["doc2", "doc1", "greetings"]

    # With k=1 and no traversal, only doc2 (Hello Earth) comes back.
    found = store.traversal_search("Earth", k=1, depth=0)
    assert _result_ids(found) == ["doc2"]

    # k=1 seeds the search with doc2 alone; depth=1 then follows the parent
    # link and the keyword link. Order is not asserted here.
    found = store.traversal_search("Earth", k=1, depth=1)
    assert set(_result_ids(found)) == {"doc2", "doc1", "greetings"}
252+
253+
254+
@pytest.mark.parametrize("gs_factory", ["cassandra", "astra_db"])
def test_metadata(request, gs_factory: str):
    """Check that a stored document returns its id, plain metadata and links."""
    factory: GraphStoreFactory = request.getfixturevalue(gs_factory)

    source_doc = Document(
        id="a",
        page_content="A",
        metadata={
            METADATA_LINKS_KEY: {
                Link.incoming(kind="hyperlink", tag="http://a"),
                Link.bidir(kind="other", tag="foo"),
            },
            "other": "some other field",
        },
    )
    store = factory.store([source_doc])

    results = store.similarity_search("A")
    assert len(results) == 1

    retrieved = results[0]
    retrieved_metadata = retrieved.metadata
    assert retrieved_metadata["other"] == "some other field"
    assert retrieved.id == "a"

    # Built fresh rather than reused from `source_doc`, so the expectation does
    # not depend on what the store does with the set it was given.
    expected_links = {
        Link.incoming(kind="hyperlink", tag="http://a"),
        Link.bidir(kind="other", tag="foo"),
    }
    assert set(retrieved_metadata[METADATA_LINKS_KEY]) == expected_links

0 commit comments

Comments
 (0)