Skip to content

Commit 432d422

Browse files
committed
Pipeline tests
1 parent ed00131 commit 432d422

File tree

2 files changed

+310
-3
lines changed

2 files changed

+310
-3
lines changed

tests/test_index.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Tests for index operation utilities."""
2+
13
import os
24
from collections.abc import Generator
35

@@ -34,7 +36,6 @@ def collection(client) -> Generator:
3436
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
3537
else:
3638
clxn = client[DB_NAME][COLLECTION_NAME]
37-
clxn = client[DB_NAME][COLLECTION_NAME]
3839
clxn.delete_many({})
3940
yield clxn
4041
clxn.delete_many({})
@@ -81,7 +82,6 @@ def test_search_index_update_vector_search_index(collection: Collection) -> None
8182
similarity_orig = "cosine"
8283
similarity_new = "euclidean"
8384

84-
# Create another index
8585
create_vector_search_index(
8686
collection=collection,
8787
index_name=index_name,
@@ -97,7 +97,6 @@ def test_search_index_update_vector_search_index(collection: Collection) -> None
9797
assert indexes[0]["name"] == index_name
9898
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_orig
9999

100-
# Update the index and test new similarity
101100
update_vector_search_index(
102101
collection=collection,
103102
index_name=index_name,

tests/test_pipeline.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
"""Tests for pipeline aggregation generator utilities."""
2+
3+
from pymongo_vectorsearch_utils.pipeline import (
4+
combine_pipelines,
5+
final_hybrid_stage,
6+
reciprocal_rank_stage,
7+
text_search_stage,
8+
vector_search_stage,
9+
)
10+
11+
12+
class TestTextSearchStage:
13+
def test_basic_text_search(self):
14+
result = text_search_stage(
15+
query="test query", search_field="content", index_name="test_index"
16+
)
17+
18+
expected = [
19+
{
20+
"$search": {
21+
"index": "test_index",
22+
"text": {"query": "test query", "path": "content"},
23+
}
24+
},
25+
{"$set": {"score": {"$meta": "searchScore"}}},
26+
]
27+
28+
assert result == expected
29+
30+
def test_text_search_with_multiple_fields(self):
31+
result = text_search_stage(
32+
query="test query", search_field=["title", "content"], index_name="test_index"
33+
)
34+
35+
assert result[0]["$search"]["text"]["path"] == ["title", "content"]
36+
37+
def test_text_search_with_filter(self):
38+
filter_dict = {"category": "tech"}
39+
result = text_search_stage(
40+
query="test query", search_field="content", index_name="test_index", filter=filter_dict
41+
)
42+
43+
assert {"$match": filter_dict} in result
44+
45+
def test_text_search_with_limit(self):
46+
result = text_search_stage(
47+
query="test query", search_field="content", index_name="test_index", limit=10
48+
)
49+
50+
assert {"$limit": 10} in result
51+
52+
def test_text_search_without_scores(self):
53+
result = text_search_stage(
54+
query="test query",
55+
search_field="content",
56+
index_name="test_index",
57+
include_scores=False,
58+
)
59+
60+
score_stage = {"$set": {"score": {"$meta": "searchScore"}}}
61+
assert score_stage not in result
62+
63+
def test_text_search_with_all_parameters(self):
64+
filter_dict = {"status": "published"}
65+
result = text_search_stage(
66+
query="test query",
67+
search_field=["title", "description", "content"],
68+
index_name="test_index",
69+
limit=20,
70+
filter=filter_dict,
71+
include_scores=True,
72+
)
73+
74+
assert len(result) == 4
75+
assert result[0]["$search"]["index"] == "test_index"
76+
assert result[1] == {"$match": filter_dict}
77+
assert result[2] == {"$set": {"score": {"$meta": "searchScore"}}}
78+
assert result[3] == {"$limit": 20}
79+
80+
81+
class TestVectorSearchStage:
82+
def test_basic_vector_search(self):
83+
query_vector = [0.1, 0.2, 0.3, 0.4]
84+
result = vector_search_stage(
85+
query_vector=query_vector, search_field="embedding", index_name="vector_index"
86+
)
87+
88+
expected = {
89+
"$vectorSearch": {
90+
"index": "vector_index",
91+
"path": "embedding",
92+
"queryVector": query_vector,
93+
"numCandidates": 40,
94+
"limit": 4,
95+
}
96+
}
97+
98+
assert result == expected
99+
100+
def test_vector_search_with_custom_top_k(self):
101+
query_vector = [0.1, 0.2, 0.3]
102+
result = vector_search_stage(
103+
query_vector=query_vector, search_field="embedding", index_name="vector_index", top_k=10
104+
)
105+
106+
assert result["$vectorSearch"]["limit"] == 10
107+
assert result["$vectorSearch"]["numCandidates"] == 100
108+
109+
def test_vector_search_with_custom_oversampling(self):
110+
query_vector = [0.1, 0.2, 0.3]
111+
result = vector_search_stage(
112+
query_vector=query_vector,
113+
search_field="embedding",
114+
index_name="vector_index",
115+
top_k=5,
116+
oversampling_factor=20,
117+
)
118+
119+
assert result["$vectorSearch"]["numCandidates"] == 100
120+
121+
def test_vector_search_with_filter(self):
122+
query_vector = [0.1, 0.2, 0.3]
123+
filter_dict = {"metadata.category": "science"}
124+
result = vector_search_stage(
125+
query_vector=query_vector,
126+
search_field="embedding",
127+
index_name="vector_index",
128+
filter=filter_dict,
129+
)
130+
131+
assert result["$vectorSearch"]["filter"] == filter_dict
132+
133+
def test_vector_search_with_all_parameters(self):
134+
query_vector = [0.1, 0.2, 0.3, 0.4, 0.5]
135+
filter_dict = {"published": True, "language": "en"}
136+
result = vector_search_stage(
137+
query_vector=query_vector,
138+
search_field="text_embedding",
139+
index_name="content_vector_index",
140+
top_k=15,
141+
filter=filter_dict,
142+
oversampling_factor=8,
143+
)
144+
145+
expected = {
146+
"$vectorSearch": {
147+
"index": "content_vector_index",
148+
"path": "text_embedding",
149+
"queryVector": query_vector,
150+
"numCandidates": 120,
151+
"limit": 15,
152+
"filter": filter_dict,
153+
}
154+
}
155+
156+
assert result == expected
157+
158+
159+
class TestCombinePipelines:
160+
def test_combine_with_empty_pipeline(self):
161+
pipeline = []
162+
stage = [{"$match": {"field": "value"}}]
163+
164+
combine_pipelines(pipeline, stage, "test_collection")
165+
166+
assert pipeline == stage
167+
168+
def test_combine_with_existing_pipeline(self):
169+
pipeline = [{"$search": {"index": "test"}}]
170+
stage = [{"$vectorSearch": {"index": "vector_test"}}]
171+
172+
combine_pipelines(pipeline, stage, "test_collection")
173+
174+
expected_union = {"$unionWith": {"coll": "test_collection", "pipeline": stage}}
175+
176+
assert len(pipeline) == 2
177+
assert pipeline[1] == expected_union
178+
179+
def test_combine_modifies_in_place(self):
180+
original_pipeline = [{"$match": {"test": True}}]
181+
pipeline = original_pipeline.copy()
182+
stage = [{"$project": {"field": 1}}]
183+
184+
combine_pipelines(pipeline, stage, "collection")
185+
186+
assert len(original_pipeline) == 1
187+
assert len(pipeline) == 2
188+
189+
190+
class TestReciprocalRankStage:
191+
def test_basic_reciprocal_rank(self):
192+
result = reciprocal_rank_stage(score_field="text_score")
193+
194+
expected = [
195+
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
196+
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
197+
{
198+
"$addFields": {
199+
"docs.text_score": {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]},
200+
"docs.rank": "$rank",
201+
"_id": "$docs._id",
202+
}
203+
},
204+
{"$replaceRoot": {"newRoot": "$docs"}},
205+
]
206+
207+
assert result == expected
208+
209+
def test_reciprocal_rank_with_penalty(self):
210+
result = reciprocal_rank_stage(score_field="vector_score", penalty=60)
211+
212+
add_fields_stage = result[2]["$addFields"]
213+
divide_expr = add_fields_stage["docs.vector_score"]["$divide"]
214+
add_expr = divide_expr[1]["$add"]
215+
216+
assert add_expr == ["$rank", 60, 1]
217+
218+
def test_reciprocal_rank_custom_score_field(self):
219+
result = reciprocal_rank_stage(score_field="custom_score_field")
220+
221+
add_fields_stage = result[2]["$addFields"]
222+
assert "docs.custom_score_field" in add_fields_stage
223+
224+
def test_reciprocal_rank_with_kwargs(self):
225+
result = reciprocal_rank_stage(score_field="test_score", penalty=10, extra_param="ignored")
226+
227+
assert len(result) == 4
228+
assert result[2]["$addFields"]["docs.test_score"]["$divide"][1]["$add"] == ["$rank", 10, 1]
229+
230+
231+
class TestFinalHybridStage:
232+
def test_basic_final_hybrid(self):
233+
result = final_hybrid_stage(scores_fields=["text_score", "vector_score"], limit=10)
234+
235+
expected = [
236+
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
237+
{"$replaceRoot": {"newRoot": "$docs"}},
238+
{
239+
"$set": {
240+
"text_score": {"$ifNull": ["$text_score", 0]},
241+
"vector_score": {"$ifNull": ["$vector_score", 0]},
242+
}
243+
},
244+
{"$addFields": {"score": {"$add": ["$text_score", "$vector_score"]}}},
245+
{"$sort": {"score": -1}},
246+
{"$limit": 10},
247+
]
248+
249+
assert result == expected
250+
251+
def test_final_hybrid_single_score(self):
252+
result = final_hybrid_stage(scores_fields=["single_score"], limit=5)
253+
254+
set_stage = result[2]["$set"]
255+
assert set_stage == {"single_score": {"$ifNull": ["$single_score", 0]}}
256+
257+
add_fields_stage = result[3]["$addFields"]
258+
assert add_fields_stage == {"score": {"$add": ["$single_score"]}}
259+
260+
assert result[5] == {"$limit": 5}
261+
262+
def test_final_hybrid_multiple_scores(self):
263+
scores = ["text_score", "vector_score", "semantic_score"]
264+
result = final_hybrid_stage(scores_fields=scores, limit=20)
265+
266+
set_stage = result[2]["$set"]
267+
for score in scores:
268+
assert score in set_stage
269+
assert set_stage[score] == {"$ifNull": [f"${score}", 0]}
270+
271+
add_fields_stage = result[3]["$addFields"]
272+
expected_add = {"$add": [f"${score}" for score in scores]}
273+
assert add_fields_stage["score"] == expected_add
274+
275+
def test_final_hybrid_with_kwargs(self):
276+
result = final_hybrid_stage(scores_fields=["test_score"], limit=15, extra_param="ignored")
277+
278+
assert len(result) == 6
279+
assert result[5] == {"$limit": 15}
280+
281+
282+
class TestPipelineIntegration:
283+
def test_text_and_vector_pipeline_components(self):
284+
text_pipeline = text_search_stage(
285+
query="machine learning", search_field="content", index_name="text_index", limit=10
286+
)
287+
288+
vector_stage = vector_search_stage(
289+
query_vector=[0.1, 0.2, 0.3],
290+
search_field="embedding",
291+
index_name="vector_index",
292+
top_k=10,
293+
)
294+
295+
assert isinstance(text_pipeline, list)
296+
assert isinstance(vector_stage, dict)
297+
assert "$search" in text_pipeline[0]
298+
assert "$vectorSearch" in vector_stage
299+
300+
def test_rrf_and_final_stages_compatibility(self):
301+
rrf_stage = reciprocal_rank_stage(score_field="text_score")
302+
final_stage = final_hybrid_stage(scores_fields=["text_score", "vector_score"], limit=5)
303+
304+
rrf_field_creation = rrf_stage[2]["$addFields"]
305+
assert "docs.text_score" in rrf_field_creation
306+
307+
final_set_stage = final_stage[2]["$set"]
308+
assert "text_score" in final_set_stage

0 commit comments

Comments
 (0)