1
1
"""Tests for operation utilities."""
2
2
3
3
import os
4
- from unittest .mock import Mock
4
+ from unittest .mock import Mock , patch
5
5
6
6
import pytest
7
7
from bson import ObjectId
8
8
from pymongo import MongoClient
9
9
from pymongo .collection import Collection
10
10
11
- from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts
11
+ from pymongo_vectorsearch_utils import drop_vector_search_index
12
+ from pymongo_vectorsearch_utils .index import create_vector_search_index , wait_for_docs_in_index
13
+ from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts , execute_search_query
12
14
13
15
# Database that the fixtures and search tests visible in this module operate on.
DB_NAME = "vectorsearch_utils_test"
14
16
# Collection inside DB_NAME that preserved_collection and the search index fixture use.
COLLECTION_NAME = "test_operation"
17
# Name of the vector search index created, queried, and dropped by TestExecuteSearchQuery.
+ VECTOR_INDEX_NAME = "operation_vector_index"
15
18
16
19
17
20
@pytest .fixture (scope = "module" )
@@ -21,6 +24,15 @@ def client():
21
24
yield client
22
25
client .close ()
23
26
27
@pytest.fixture(scope="module")
def preserved_collection(client):
    """Yield the shared test collection, emptied before and after use.

    The collection named ``COLLECTION_NAME`` in ``DB_NAME`` is reused when it
    is already listed by the database and created otherwise. Because the
    fixture is module-scoped, documents inserted through it stay in place
    across tests until the module finishes.
    """
    database = client[DB_NAME]
    if COLLECTION_NAME in database.list_collection_names():
        shared = database[COLLECTION_NAME]
    else:
        shared = database.create_collection(COLLECTION_NAME)

    # Start from a clean slate regardless of what earlier runs left behind.
    shared.delete_many({})
    yield shared
    # Teardown: remove everything the dependent tests inserted.
    shared.delete_many({})
24
36
25
37
@pytest .fixture
26
38
def collection (client ):
@@ -266,3 +278,176 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266
278
assert "vector" in doc
267
279
assert doc ["content" ] == texts [0 ]
268
280
assert doc ["vector" ] == [0.0 , 0.0 , 0.0 ]
281
+
282
+
283
class TestExecuteSearchQuery:
    """Tests for ``execute_search_query`` run against a small seeded collection.

    Two class-scoped, autouse fixtures prepare the environment: one makes sure
    the vector search index exists (and drops it afterwards), the other seeds
    four text documents plus one document that has no ``text`` field.
    """

    @pytest.fixture(scope="class", autouse=True)
    def vector_search_index(self, client):
        """Create the vector search index if it is missing; drop it on teardown."""
        target = client[DB_NAME][COLLECTION_NAME]
        already_defined = target.list_search_indexes(VECTOR_INDEX_NAME).to_list()
        if not already_defined:
            create_vector_search_index(
                collection=target,
                index_name=VECTOR_INDEX_NAME,
                dimensions=3,
                path="embedding",
                similarity="cosine",
                filters=["category", "color", "wheels"],
                wait_until_complete=120,
            )
        yield
        drop_vector_search_index(collection=target, index_name=VECTOR_INDEX_NAME)

    @pytest.fixture(scope="class", autouse=True)
    def sample_docs(self, preserved_collection: Collection):
        """Seed the shared collection and return it.

        Inserts four text documents through ``bulk_embed_and_insert_texts``
        using fixed 3-dimensional vectors, adds one extra document without a
        ``text`` field, then calls ``wait_for_docs_in_index`` with ``n_docs=5``
        before handing the collection to the tests.
        """
        # One row per document: (text, metadata, embedding vector).
        corpus = [
            ("apple fruit", {"category": "fruit", "color": "red"}, [1.0, 0.5, 0.0]),
            ("banana fruit", {"category": "fruit", "color": "yellow"}, [0.5, 0.5, 0.0]),
            ("car vehicle", {"category": "vehicle", "wheels": 4}, [0.0, 0.5, 1.0]),
            ("bike vehicle", {"category": "vehicle", "wheels": 2}, [0.0, 1.0, 0.5]),
        ]
        vector_by_text = {text: vector for text, _, vector in corpus}

        def lookup_vectors(batch):
            # Hand out fresh lists so callers never share the table's vectors.
            return [list(vector_by_text[text]) for text in batch]

        bulk_embed_and_insert_texts(
            texts=[text for text, _, _ in corpus],
            metadatas=[metadata for _, metadata, _ in corpus],
            embedding_func=lookup_vectors,
            collection=preserved_collection,
            text_key="text",
            embedding_key="embedding",
        )

        # Add a document that should not be returned in searches
        textless_doc = {
            "_id": ObjectId("68c1a038fd976373aa4ec19f"),
            "category": "fruit",
            "color": "red",
            "embedding": [1.0, 1.0, 1.0],
        }
        preserved_collection.insert_one(textless_doc)

        wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
        return preserved_collection

    @staticmethod
    def _search(collection, query_vector, k, **extra):
        """Run ``execute_search_query`` with the field names and index seeded above."""
        return execute_search_query(
            query_vector=query_vector,
            collection=collection,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=k,
            **extra,
        )

    def test_basic_search_query(self, sample_docs: Collection):
        """A plain query returns the two expected texts in order, each with a score."""
        hits = self._search(sample_docs, [1.0, 0.5, 0.0], 2)

        assert len(hits) == 2
        assert [hit["text"] for hit in hits] == ["apple fruit", "banana fruit"]
        assert all("score" in hit for hit in hits)

    def test_search_with_pre_filter(self, sample_docs: Collection):
        """A ``category`` pre-filter limits the results to fruit documents."""
        hits = self._search(
            sample_docs,
            [1.0, 0.5, 1.0],
            4,
            pre_filter={"category": "fruit"},
        )

        assert len(hits) == 2
        assert all(hit["category"] == "fruit" for hit in hits)

    def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
        """Post-filter stages that keep scores >= 0.99 leave a single result."""
        stages = [
            {"$match": {"score": {"$gte": 0.99}}},
            {"$sort": {"score": -1}},
        ]

        hits = self._search(
            sample_docs,
            [1.0, 0.5, 0.0],
            2,
            post_filter_pipeline=stages,
        )

        assert len(hits) == 1

    def test_search_with_embeddings_included(self, sample_docs: Collection):
        """``include_embeddings=True`` returns the stored vector with the hit."""
        hits = self._search(
            sample_docs,
            [1.0, 0.5, 0.0],
            1,
            include_embeddings=True,
        )

        assert len(hits) == 1
        top_hit = hits[0]
        assert "embedding" in top_hit
        assert top_hit["embedding"] == [1.0, 0.5, 0.0]

    def test_search_with_custom_field_names(self, sample_docs: Collection):
        """Custom text/embedding keys are reflected in the aggregation pipeline.

        ``aggregate`` is patched on the collection, so the canned documents
        below stand in for a server response and the pipeline passed to it
        can be inspected.
        """
        canned_docs = [
            {
                "_id": ObjectId(),
                "content": "apple fruit",
                "vector": [1.0, 0.5, 0.25],
                "score": 0.9,
            }
        ]

        with patch.object(
            sample_docs, "aggregate", return_value=canned_docs
        ) as mock_aggregate:
            hits = execute_search_query(
                query_vector=[1.0, 0.5, 0.25],
                collection=sample_docs,
                embedding_key="vector",
                text_key="content",
                index_name=VECTOR_INDEX_NAME,
                k=1,
            )

        assert len(hits) == 1
        assert "content" in hits[0]
        assert hits[0]["content"] == "apple fruit"

        # The pipeline is the first positional argument given to aggregate().
        pipeline = mock_aggregate.call_args.args[0]
        assert pipeline[0]["$vectorSearch"]["path"] == "vector"
        assert {"$project": {"vector": 0}} in pipeline

    def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
        """The seeded document lacking ``text`` is absent from the results."""
        hits = self._search(sample_docs, [1.0, 0.5, 0.0], 3)

        # Should only return documents with text field
        assert len(hits) == 2
        assert all("text" in hit for hit in hits)
        assert [hit["text"] for hit in hits] == ["apple fruit", "banana fruit"]
0 commit comments