Skip to content

Commit 9428c3d

Browse files
authored
Use session scoped fixtures instead of static variables (#599)
1 parent 46e2b48 commit 9428c3d

File tree

12 files changed

+105
-214
lines changed

12 files changed

+105
-214
lines changed

libs/colbert/pyproject.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ torch = "2.2.1"
1717
cassio = "~0.1.7"
1818
pydantic = "^2.7.1"
1919

20+
# Workaround for https://github.com/pytorch/pytorch/pull/127921
21+
# Remove when we upgrade to pytorch 2.4
22+
setuptools = { version = ">=70", python = ">=3.12" }
23+
24+
2025
[tool.poetry.group.test.dependencies]
2126
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
2227
pytest-asyncio = "^0.23.6"
2328

24-
[tool.poetry.group.dev.dependencies]
25-
setuptools = "70.0.0"
26-
29+
[tool.pytest.ini_options]
30+
asyncio_mode = "auto"
Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
11
import pytest
2+
from cassandra.cluster import Session
23
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore
34

4-
status = {
5-
"local_cassandra_test_store": None,
6-
"astradb_test_store": None,
7-
}
85

6+
@pytest.fixture(scope="session")
7+
def cassandra() -> LocalCassandraTestStore:
8+
store = LocalCassandraTestStore()
9+
yield store
10+
if store.docker_container:
11+
store.docker_container.stop()
912

10-
def get_local_cassandra_test_store():
11-
if not status["local_cassandra_test_store"]:
12-
status["local_cassandra_test_store"] = LocalCassandraTestStore()
13-
return status["local_cassandra_test_store"]
1413

14+
@pytest.fixture(scope="session")
15+
def astra_db() -> AstraDBTestStore:
16+
return AstraDBTestStore()
1517

16-
def get_astradb_test_store():
17-
if not status["astradb_test_store"]:
18-
status["astradb_test_store"] = AstraDBTestStore()
19-
return status["astradb_test_store"]
2018

21-
22-
@pytest.hookimpl()
23-
def pytest_sessionfinish():
24-
if (
25-
status["local_cassandra_test_store"]
26-
and status["local_cassandra_test_store"].docker_container
27-
):
28-
status["local_cassandra_test_store"].docker_container.stop()
19+
@pytest.fixture()
20+
def session(request) -> Session:
21+
test_store = request.getfixturevalue(request.param)
22+
session = test_store.create_cassandra_session()
23+
session.default_timeout = 180
24+
return session

libs/colbert/tests/integration_tests/test_database.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,11 @@
11
import pytest
2+
from cassandra.cluster import Session
23
from ragstack_colbert import CassandraDatabase, Chunk
34
from ragstack_tests_utils import TestData
45

5-
from tests.integration_tests.conftest import (
6-
get_astradb_test_store,
7-
get_local_cassandra_test_store,
8-
)
9-
10-
11-
@pytest.fixture()
12-
def cassandra():
13-
return get_local_cassandra_test_store()
14-
15-
16-
@pytest.fixture()
17-
def astra_db():
18-
return get_astradb_test_store()
19-
20-
21-
@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
22-
def test_database_sync(request, vector_store: str):
23-
vector_store = request.getfixturevalue(vector_store)
246

7+
@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
8+
def test_database_sync(session: Session):
259
doc_id = "earth_doc_id"
2610

2711
chunk_0 = Chunk(
@@ -40,9 +24,6 @@ def test_database_sync(request, vector_store: str):
4024
embedding=TestData.renewable_energy_embedding(),
4125
)
4226

43-
session = vector_store.create_cassandra_session()
44-
session.default_timeout = 180
45-
4627
database = CassandraDatabase.from_session(
4728
keyspace="default_keyspace",
4829
table_name="test_database_sync",
@@ -61,11 +42,8 @@ def test_database_sync(request, vector_store: str):
6142
assert result
6243

6344

64-
@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
65-
@pytest.mark.asyncio()
66-
async def test_database_async(request, vector_store: str):
67-
vector_store = request.getfixturevalue(vector_store)
68-
45+
@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
46+
async def test_database_async(session: Session):
6947
doc_id = "earth_doc_id"
7048

7149
chunk_0 = Chunk(
@@ -84,9 +62,6 @@ async def test_database_async(request, vector_store: str):
8462
embedding=TestData.renewable_energy_embedding(),
8563
)
8664

87-
session = vector_store.create_cassandra_session()
88-
session.default_timeout = 180
89-
9065
database = CassandraDatabase.from_session(
9166
keyspace="default_keyspace",
9267
table_name="test_database_async",

libs/colbert/tests/integration_tests/test_embedding_retrieval.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,17 @@
11
import logging
22

33
import pytest
4+
from cassandra.cluster import Session
45
from ragstack_colbert import (
56
CassandraDatabase,
67
ColbertEmbeddingModel,
78
ColbertVectorStore,
89
)
910
from ragstack_tests_utils import TestData
1011

11-
from tests.integration_tests.conftest import (
12-
get_astradb_test_store,
13-
get_local_cassandra_test_store,
14-
)
15-
16-
17-
@pytest.fixture()
18-
def cassandra():
19-
return get_local_cassandra_test_store()
20-
2112

22-
@pytest.fixture()
23-
def astra_db():
24-
return get_astradb_test_store()
25-
26-
27-
@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
28-
def test_embedding_cassandra_retriever(request, vector_store: str):
29-
vector_store = request.getfixturevalue(vector_store)
13+
@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
14+
def test_embedding_cassandra_retriever(session: Session):
3015
narrative = TestData.marine_animals_text()
3116

3217
# Define the desired chunk size and overlap size
@@ -53,9 +38,6 @@ def chunk_texts(text, chunk_size, overlap_size):
5338

5439
doc_id = "marine_animals"
5540

56-
session = vector_store.create_cassandra_session()
57-
session.default_timeout = 180
58-
5941
database = CassandraDatabase.from_session(
6042
keyspace="default_keyspace",
6143
table_name="test_embedding_cassandra_retriever",

libs/knowledge-store/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ mypy = "^1.10.0"
2020
pytest-asyncio = "^0.23.6"
2121
ipykernel = "^6.29.4"
2222
testcontainers = "~3.7.1"
23-
setuptools = "^70.0.0"
2423
python-dotenv = "^1.0.1"
2524

2625
# Resolve numpy version for 3.8 to 3.12+

libs/knowledge-store/tests/integration_tests/test_graph_store.py

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
import secrets
3-
from typing import Callable, Iterable, Iterator, List
3+
from typing import Iterable, Iterator, List
44

55
import numpy as np
66
import pytest
@@ -89,26 +89,25 @@ def cassandra() -> Iterator[LocalCassandraTestStore]:
8989

9090

9191
@pytest.fixture()
92-
def graph_store_factory(
92+
def graph_store(
9393
cassandra: LocalCassandraTestStore,
94-
) -> Iterator[Callable[[], GraphStore]]:
94+
) -> Iterator[GraphStore]:
9595
session = cassandra.create_cassandra_session()
9696
session.set_keyspace(KEYSPACE)
9797

9898
embedding = SimpleEmbeddingModel()
9999

100-
def _make_graph_store() -> GraphStore:
101-
name = secrets.token_hex(8)
100+
name = secrets.token_hex(8)
102101

103-
node_table = f"nodes_{name}"
104-
return GraphStore(
105-
embedding,
106-
session=session,
107-
keyspace=KEYSPACE,
108-
node_table=node_table,
109-
)
102+
node_table = f"nodes_{name}"
103+
store = GraphStore(
104+
embedding,
105+
session=session,
106+
keyspace=KEYSPACE,
107+
node_table=node_table,
108+
)
110109

111-
yield _make_graph_store
110+
yield store
112111

113112
session.shutdown()
114113

@@ -117,15 +116,7 @@ def _result_ids(nodes: Iterable[Node]) -> List[str]:
117116
return [n.id for n in nodes if n.id is not None]
118117

119118

120-
def test_graph_store_creation(graph_store_factory: Callable[[], GraphStore]) -> None:
121-
"""Test that a graph store can be created.
122-
123-
This verifies the schema can be applied and the queries prepared.
124-
"""
125-
graph_store_factory()
126-
127-
128-
def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
119+
def test_mmr_traversal(graph_store: GraphStore) -> None:
129120
"""
130121
Test end to end construction and MMR search.
131122
The embedding function used here ensures `texts` become
@@ -145,8 +136,6 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
145136
Both v2 and v3 are reachable via edges from v0, so once it is
146137
selected, those are both considered.
147138
"""
148-
gs = graph_store_factory()
149-
150139
v0 = Node(
151140
id="v0",
152141
text="-0.124",
@@ -166,32 +155,30 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
166155
text="+1.0",
167156
links={Link(direction="in", kind="explicit", tag="link")},
168157
)
169-
gs.add_nodes([v0, v1, v2, v3])
158+
graph_store.add_nodes([v0, v1, v2, v3])
170159

171-
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2)
160+
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
172161
assert _result_ids(results) == ["v0", "v2"]
173162

174163
# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
175164
# So it ends up picking "v1" even though it's similar to "v0".
176-
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
165+
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
177166
assert _result_ids(results) == ["v0", "v1"]
178167

179168
# With max depth 0 but higher `fetch_k`, we encounter v2
180-
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
169+
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
181170
assert _result_ids(results) == ["v0", "v2"]
182171

183172
# v0 score is .46, v2 score is 0.16 so it won't be chosen.
184-
results = gs.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
173+
results = graph_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
185174
assert _result_ids(results) == ["v0"]
186175

187176
# with k=4 we should get all of the documents.
188-
results = gs.mmr_traversal_search("0.0", k=4)
177+
results = graph_store.mmr_traversal_search("0.0", k=4)
189178
assert _result_ids(results) == ["v0", "v2", "v1", "v3"]
190179

191180

192-
def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore]) -> None:
193-
gs = graph_store_factory()
194-
181+
def test_write_retrieve_keywords(graph_store: GraphStore) -> None:
195182
greetings = Node(
196183
id="greetings",
197184
text="Typical Greetings",
@@ -218,36 +205,34 @@ def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore])
218205
},
219206
)
220207

221-
gs.add_nodes([greetings, doc1, doc2])
208+
graph_store.add_nodes([greetings, doc1, doc2])
222209

223210
# Doc2 is more similar, but World and Earth are similar enough that doc1 also shows
224211
# up.
225-
results = gs.similarity_search(text_to_embedding("Earth"), k=2)
212+
results = graph_store.similarity_search(text_to_embedding("Earth"), k=2)
226213
assert _result_ids(results) == ["doc2", "doc1"]
227214

228-
results = gs.similarity_search(text_to_embedding("Earth"), k=1)
215+
results = graph_store.similarity_search(text_to_embedding("Earth"), k=1)
229216
assert _result_ids(results) == ["doc2"]
230217

231-
results = gs.traversal_search("Earth", k=2, depth=0)
218+
results = graph_store.traversal_search("Earth", k=2, depth=0)
232219
assert _result_ids(results) == ["doc2", "doc1"]
233220

234-
results = gs.traversal_search("Earth", k=2, depth=1)
221+
results = graph_store.traversal_search("Earth", k=2, depth=1)
235222
assert _result_ids(results) == ["doc2", "doc1", "greetings"]
236223

237224
# K=1 only pulls in doc2 (Hello Earth)
238-
results = gs.traversal_search("Earth", k=1, depth=0)
225+
results = graph_store.traversal_search("Earth", k=1, depth=0)
239226
assert _result_ids(results) == ["doc2"]
240227

241228
# K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via keyword
242229
# edge.
243-
results = gs.traversal_search("Earth", k=1, depth=1)
230+
results = graph_store.traversal_search("Earth", k=1, depth=1)
244231
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}
245232

246233

247-
def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
248-
gs = graph_store_factory()
249-
250-
gs.add_nodes(
234+
def test_metadata(graph_store: GraphStore) -> None:
235+
graph_store.add_nodes(
251236
[
252237
Node(
253238
id="a",
@@ -260,7 +245,7 @@ def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
260245
)
261246
]
262247
)
263-
results = list(gs.similarity_search(text_to_embedding("A")))
248+
results = list(graph_store.similarity_search(text_to_embedding("A")))
264249
assert len(results) == 1
265250
assert results[0].id == "a"
266251
assert results[0].metadata["other"] == "some other field"

libs/langchain/pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,5 @@ pytest-asyncio = "^0.23.6"
4242
keybert = "^0.8.5"
4343
gliner = "^0.2.5"
4444

45-
[tool.poetry.group.dev.dependencies]
46-
setuptools = "^70.0.0"
47-
45+
[tool.pytest.ini_options]
46+
asyncio_mode = "auto"

0 commit comments

Comments
 (0)