Skip to content

Commit 57a43cc

Browse files
committed
Add tests
1 parent 458dd74 commit 57a43cc

File tree

3 files changed

+242
-5
lines changed

3 files changed

+242
-5
lines changed

pymongo_vectorsearch_utils/index.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ def create_fulltext_search_index(
259259
def wait_for_docs_in_index(
260260
collection: Collection[Any],
261261
index_name: str,
262-
embedding_field: str,
263262
n_docs: int,
264263
) -> bool:
265264
"""Wait until the given number of documents are indexed by the given index.
@@ -270,12 +269,16 @@ def wait_for_docs_in_index(
270269
embedding_field (str): The name of the document field containing embeddings.
271270
n_docs (int): The number of documents to expect in the index.
272271
"""
273-
query_vector = [0.0] * 1024 # Dummy vector
272+
index = collection.list_search_indexes(index_name).to_list()[0]
273+
num_dimensions = index["latestDefinition"]["fields"][0]["numDimensions"]
274+
field = index["latestDefinition"]["fields"][0]["path"]
275+
276+
query_vector = [0.001] * num_dimensions # Dummy vector
274277
query = [
275278
{
276279
"$vectorSearch": {
277280
"index": index_name,
278-
"path": embedding_field,
281+
"path": field,
279282
"queryVector": query_vector,
280283
"numCandidates": n_docs,
281284
"limit": n_docs,

pymongo_vectorsearch_utils/operation.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pymongo import ReplaceOne
88
from pymongo.synchronous.collection import Collection
99

10+
from pymongo_vectorsearch_utils.pipeline import vector_search_stage
1011
from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid
1112

1213

@@ -60,3 +61,51 @@ def bulk_embed_and_insert_texts(
6061
result = collection.bulk_write(operations)
6162
assert result.upserted_ids is not None
6263
return [oid_to_str(_id) for _id in result.upserted_ids.values()]
64+
65+
66+
def execute_search_query(
67+
query_vector: list[float],
68+
collection: Collection[Any],
69+
embedding_key: str,
70+
text_key: str,
71+
index_name: str,
72+
k: int = 4,
73+
pre_filter: dict[str, Any] | None = None,
74+
post_filter_pipeline: list[dict[str, Any]] | None = None,
75+
oversampling_factor: int = 10,
76+
include_embeddings: bool = False,
77+
**kwargs: Any,
78+
) -> list[tuple[Any, float]]:
79+
"""Execute a MongoDB vector search query."""
80+
81+
# Atlas Vector Search, potentially with filter
82+
pipeline = [
83+
vector_search_stage(
84+
query_vector,
85+
embedding_key,
86+
index_name,
87+
k,
88+
pre_filter,
89+
oversampling_factor,
90+
**kwargs,
91+
),
92+
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
93+
]
94+
95+
96+
# Remove embeddings unless requested.
97+
if not include_embeddings:
98+
pipeline.append({"$project": {embedding_key: 0}})
99+
# Post-processing
100+
if post_filter_pipeline is not None:
101+
pipeline.extend(post_filter_pipeline)
102+
103+
# Execution
104+
cursor = collection.aggregate(pipeline)
105+
docs = []
106+
107+
for doc in cursor:
108+
if text_key not in doc:
109+
continue
110+
docs.append(doc)
111+
return docs

tests/test_operation.py

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
"""Tests for operation utilities."""
22

33
import os
4-
from unittest.mock import Mock
4+
from unittest.mock import Mock, patch
55

66
import pytest
77
from bson import ObjectId
88
from pymongo import MongoClient
99
from pymongo.collection import Collection
1010

11-
from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts
11+
from pymongo_vectorsearch_utils import drop_vector_search_index
12+
from pymongo_vectorsearch_utils.index import create_vector_search_index, wait_for_docs_in_index
13+
from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts, execute_search_query
1214

1315
DB_NAME = "vectorsearch_utils_test"
1416
COLLECTION_NAME = "test_operation"
17+
VECTOR_INDEX_NAME = "operation_vector_index"
1518

1619

1720
@pytest.fixture(scope="module")
@@ -21,6 +24,15 @@ def client():
2124
yield client
2225
client.close()
2326

27+
@pytest.fixture(scope="module")
28+
def preserved_collection(client):
29+
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
30+
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
31+
else:
32+
clxn = client[DB_NAME][COLLECTION_NAME]
33+
clxn.delete_many({})
34+
yield clxn
35+
clxn.delete_many({})
2436

2537
@pytest.fixture
2638
def collection(client):
@@ -266,3 +278,176 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266278
assert "vector" in doc
267279
assert doc["content"] == texts[0]
268280
assert doc["vector"] == [0.0, 0.0, 0.0]
281+
282+
283+
class TestExecuteSearchQuery:
284+
@pytest.fixture(scope="class", autouse=True)
285+
def vector_search_index(self, client):
286+
coll = client[DB_NAME][COLLECTION_NAME]
287+
if len(coll.list_search_indexes(VECTOR_INDEX_NAME).to_list()) == 0:
288+
create_vector_search_index(
289+
collection=coll,
290+
index_name=VECTOR_INDEX_NAME,
291+
dimensions=3,
292+
path="embedding",
293+
similarity="cosine",
294+
filters=["category", "color", "wheels"],
295+
wait_until_complete=120,
296+
)
297+
yield
298+
drop_vector_search_index(collection=coll, index_name=VECTOR_INDEX_NAME)
299+
300+
@pytest.fixture(scope="class", autouse=True)
301+
def sample_docs(self, preserved_collection: Collection):
302+
texts = ["apple fruit", "banana fruit", "car vehicle", "bike vehicle"]
303+
metadatas = [
304+
{"category": "fruit", "color": "red"},
305+
{"category": "fruit", "color": "yellow"},
306+
{"category": "vehicle", "wheels": 4},
307+
{"category": "vehicle", "wheels": 2},
308+
]
309+
310+
def embeddings(texts):
311+
mapping = {
312+
"apple fruit": [1.0, 0.5, 0.0],
313+
"banana fruit": [0.5, 0.5, 0.0],
314+
"car vehicle": [0.0, 0.5, 1.0],
315+
"bike vehicle": [0.0, 1.0, 0.5],
316+
}
317+
return [mapping[text] for text in texts]
318+
319+
bulk_embed_and_insert_texts(
320+
texts=texts,
321+
metadatas=metadatas,
322+
embedding_func=embeddings,
323+
collection=preserved_collection,
324+
text_key="text",
325+
embedding_key="embedding",
326+
)
327+
# Add a document that should not be returned in searches
328+
preserved_collection.insert_one({'_id': ObjectId('68c1a038fd976373aa4ec19f'), 'category': 'fruit', 'color': 'red', 'embedding': [1.0, 1.0, 1.0]})
329+
wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
330+
return preserved_collection
331+
332+
def test_basic_search_query(self, sample_docs: Collection):
333+
query_vector = [1.0, 0.5, 0.0]
334+
335+
result = execute_search_query(
336+
query_vector=query_vector,
337+
collection=sample_docs,
338+
embedding_key="embedding",
339+
text_key="text",
340+
index_name=VECTOR_INDEX_NAME,
341+
k=2,
342+
)
343+
344+
assert len(result) == 2
345+
assert result[0]["text"] == "apple fruit"
346+
assert result[1]["text"] == "banana fruit"
347+
assert "score" in result[0]
348+
assert "score" in result[1]
349+
350+
def test_search_with_pre_filter(self, sample_docs: Collection):
351+
query_vector = [1.0, 0.5, 1.0]
352+
pre_filter = {"category": "fruit"}
353+
354+
result = execute_search_query(
355+
query_vector=query_vector,
356+
collection=sample_docs,
357+
embedding_key="embedding",
358+
text_key="text",
359+
index_name=VECTOR_INDEX_NAME,
360+
k=4,
361+
pre_filter=pre_filter,
362+
)
363+
364+
assert len(result) == 2
365+
assert result[0]["category"] == "fruit"
366+
assert result[1]["category"] == "fruit"
367+
368+
def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
369+
query_vector = [1.0, 0.5, 0.0]
370+
post_filter_pipeline = [
371+
{"$match": {"score": {"$gte": 0.99}}},
372+
{"$sort": {"score": -1}},
373+
]
374+
375+
result = execute_search_query(
376+
query_vector=query_vector,
377+
collection=sample_docs,
378+
embedding_key="embedding",
379+
text_key="text",
380+
index_name=VECTOR_INDEX_NAME,
381+
k=2,
382+
post_filter_pipeline=post_filter_pipeline,
383+
)
384+
385+
assert len(result) == 1
386+
387+
def test_search_with_embeddings_included(self, sample_docs: Collection):
388+
query_vector = [1.0, 0.5, 0.0]
389+
390+
result = execute_search_query(
391+
query_vector=query_vector,
392+
collection=sample_docs,
393+
embedding_key="embedding",
394+
text_key="text",
395+
index_name=VECTOR_INDEX_NAME,
396+
k=1,
397+
include_embeddings=True,
398+
)
399+
400+
assert len(result) == 1
401+
assert "embedding" in result[0]
402+
assert result[0]["embedding"] == [1.0, 0.5, 0.0]
403+
404+
def test_search_with_custom_field_names(self, sample_docs: Collection):
405+
query_vector = [1.0, 0.5, 0.25]
406+
407+
mock_cursor = [
408+
{
409+
"_id": ObjectId(),
410+
"content": "apple fruit",
411+
"vector": [1.0, 0.5, 0.25],
412+
"score": 0.9,
413+
}
414+
]
415+
416+
with patch.object(sample_docs, "aggregate") as mock_aggregate:
417+
mock_aggregate.return_value = mock_cursor
418+
419+
result = execute_search_query(
420+
query_vector=query_vector,
421+
collection=sample_docs,
422+
embedding_key="vector",
423+
text_key="content",
424+
index_name=VECTOR_INDEX_NAME,
425+
k=1,
426+
)
427+
428+
assert len(result) == 1
429+
assert "content" in result[0]
430+
assert result[0]["content"] == "apple fruit"
431+
432+
pipeline_arg = mock_aggregate.call_args[0][0]
433+
vector_search_stage = pipeline_arg[0]["$vectorSearch"]
434+
assert vector_search_stage["path"] == "vector"
435+
assert {"$project": {"vector": 0}} in pipeline_arg
436+
437+
def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
438+
query_vector = [1.0, 0.5, 0.0]
439+
440+
result = execute_search_query(
441+
query_vector=query_vector,
442+
collection=sample_docs,
443+
embedding_key="embedding",
444+
text_key="text",
445+
index_name=VECTOR_INDEX_NAME,
446+
k=3,
447+
)
448+
449+
# Should only return documents with text field
450+
assert len(result) == 2
451+
assert all("text" in doc for doc in result)
452+
assert result[0]["text"] == "apple fruit"
453+
assert result[1]["text"] == "banana fruit"

0 commit comments

Comments
 (0)