Skip to content

Commit 4746bd8

Browse files
committed
Add vector search test.
1 parent 165bda8 commit 4746bd8

File tree

4 files changed

+117
-69
lines changed

4 files changed

+117
-69
lines changed

django_mongodb_backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,13 @@ def _build_aggregation_pipeline(self, ids, group):
248248
pipeline.append({"$unset": "_id"})
249249
return pipeline
250250

251-
def _compound_searches_queries(self, searches):
251+
def _compound_searches_queries(self, searches, search_replacements):
252252
if not searches:
253253
return []
254254
if len(searches) > 1:
255255
raise ValueError("Cannot perform more than one search operation.")
256-
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}]
256+
score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore"
257+
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}]
257258

258259
def pre_sql_setup(self, with_col_aliases=False):
259260
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
@@ -262,7 +263,7 @@ def pre_sql_setup(self, with_col_aliases=False):
262263
)
263264
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
264265
all_replacements = {**search_replacements, **group_replacements}
265-
self.search_pipeline = self._compound_searches_queries(searches)
266+
self.search_pipeline = self._compound_searches_queries(searches, search_replacements)
266267
# query.group_by is either:
267268
# - None: no GROUP BY
268269
# - True: group by select fields

django_mongodb_backend/expressions/builtins.py

Lines changed: 62 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -598,61 +598,6 @@ def get_search_fields(self):
598598
return needed_fields
599599

600600

601-
class SearchVector(SearchExpression):
602-
def __init__(
603-
self,
604-
path,
605-
query_vector,
606-
limit,
607-
num_candidates=None,
608-
exact=None,
609-
filter=None,
610-
):
611-
self.path = path
612-
self.query_vector = query_vector
613-
self.limit = limit
614-
self.num_candidates = num_candidates
615-
self.exact = exact
616-
self.filter = filter
617-
super().__init__()
618-
619-
def __invert__(self):
620-
return ValueError("SearchVector cannot be negated")
621-
622-
def __and__(self, other):
623-
raise NotSupportedError("SearchVector cannot be combined")
624-
625-
def __rand__(self, other):
626-
raise NotSupportedError("SearchVector cannot be combined")
627-
628-
def __or__(self, other):
629-
raise NotSupportedError("SearchVector cannot be combined")
630-
631-
def __ror__(self, other):
632-
raise NotSupportedError("SearchVector cannot be combined")
633-
634-
def get_search_fields(self):
635-
return {self.path}
636-
637-
def _get_query_index(self, field, compiler):
638-
return "default"
639-
640-
def as_mql(self, compiler, connection):
641-
params = {
642-
"index": self._get_query_index(self.get_search_fields()),
643-
"path": self.path,
644-
"queryVector": self.query_vector,
645-
"limit": self.limit,
646-
}
647-
if self.num_candidates is not None:
648-
params["numCandidates"] = self.num_candidates
649-
if self.exact is not None:
650-
params["exact"] = self.exact
651-
if self.filter is not None:
652-
params["filter"] = self.filter
653-
return {"$vectorSearch": params}
654-
655-
656601
class CompoundExpression(SearchExpression):
657602
def __init__(
658603
self,
@@ -716,22 +661,76 @@ def resolve(node, negated=False):
716661
if operator == Operator.OR:
717662
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
718663
if operator == Operator.AND:
719-
# NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
720-
# return CompoundExpression(
721-
# must=lhs_compound.must + rhs_compound.must,
722-
# must_not=lhs_compound.must_not + rhs_compound.must_not,
723-
# should=lhs_compound.should + rhs_compound.should,
724-
# filter=lhs_compound.filter + rhs_compound.filter,
725-
# )
726664
return CompoundExpression(must=[lhs_compound, rhs_compound])
727-
# not operator
728665
return lhs_compound
729666

730667
def as_mql(self, compiler, connection):
731668
expression = self.resolve(self)
732669
return expression.as_mql(compiler, connection)
733670

734671

672+
class SearchVector(SearchExpression):
673+
def __init__(
674+
self,
675+
path,
676+
query_vector,
677+
limit,
678+
num_candidates=None,
679+
exact=None,
680+
filter=None,
681+
):
682+
self.path = path
683+
self.query_vector = query_vector
684+
self.limit = limit
685+
self.num_candidates = num_candidates
686+
self.exact = exact
687+
self.filter = filter
688+
super().__init__()
689+
690+
def __invert__(self):
691+
return ValueError("SearchVector cannot be negated")
692+
693+
def __and__(self, other):
694+
raise NotSupportedError("SearchVector cannot be combined")
695+
696+
def __rand__(self, other):
697+
raise NotSupportedError("SearchVector cannot be combined")
698+
699+
def __or__(self, other):
700+
raise NotSupportedError("SearchVector cannot be combined")
701+
702+
def __ror__(self, other):
703+
raise NotSupportedError("SearchVector cannot be combined")
704+
705+
def get_search_fields(self):
706+
return {self.path}
707+
708+
def _get_query_index(self, fields, compiler):
709+
for search_indexes in compiler.collection.list_search_indexes():
710+
if search_indexes["type"] == "vectorSearch":
711+
index_field = {
712+
field["path"] for field in search_indexes["latestDefinition"]["fields"]
713+
}
714+
if fields.issubset(index_field):
715+
return search_indexes["name"]
716+
return "default"
717+
718+
def as_mql(self, compiler, connection):
719+
params = {
720+
"index": self._get_query_index(self.get_search_fields(), compiler),
721+
"path": self.path,
722+
"queryVector": self.query_vector,
723+
"limit": self.limit,
724+
}
725+
if self.num_candidates is not None:
726+
params["numCandidates"] = self.num_candidates
727+
if self.exact is not None:
728+
params["exact"] = self.exact
729+
if self.filter is not None:
730+
params["filter"] = self.filter
731+
return {"$vectorSearch": params}
732+
733+
735734
class SearchScoreOption:
736735
"""Class to mutate scoring on a search operation"""
737736

tests/queries_/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.db import models
22

3-
from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField
3+
from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField
44

55

66
class Author(models.Model):
@@ -60,3 +60,4 @@ class Article(models.Model):
6060
number = models.IntegerField()
6161
body = models.TextField()
6262
location = models.JSONField(null=True)
63+
plot_embedding = ArrayField(models.FloatField(), size=3)

tests/queries_/test_search.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
SearchRange,
1818
SearchRegex,
1919
SearchText,
20+
SearchVector,
2021
SearchWildcard,
2122
)
2223

@@ -29,9 +30,9 @@ def _get_collection(model):
2930
return connection.database.get_collection(model._meta.db_table)
3031

3132
@staticmethod
32-
def create_search_index(model, index_name, definition):
33+
def create_search_index(model, index_name, definition, type="search"):
3334
collection = CreateIndexMixin._get_collection(model)
34-
idx = SearchIndexModel(definition=definition, name=index_name)
35+
idx = SearchIndexModel(definition=definition, name=index_name, type=type)
3536
collection.create_search_index(idx)
3637

3738

@@ -365,3 +366,49 @@ def test_compound_operations(self):
365366
)
366367
qs = Article.objects.annotate(score=expr)
367368
self.assertCountEqual(qs, [self.mars_mission, self.exoplanet])
369+
370+
371+
class SearchVectorTest(TestCase, CreateIndexMixin):
372+
@classmethod
373+
def setUpTestData(cls):
374+
cls.create_search_index(
375+
Article,
376+
"vector_index",
377+
{
378+
"fields": [
379+
{
380+
"type": "vector",
381+
"path": "plot_embedding",
382+
"numDimensions": 3,
383+
"similarity": "cosine",
384+
"quantization": "scalar",
385+
}
386+
]
387+
},
388+
type="vectorSearch",
389+
)
390+
391+
cls.mars = Article.objects.create(
392+
headline="Mars landing",
393+
number=1,
394+
body="The rover has landed on Mars",
395+
plot_embedding=[0.1, 0.2, 0.3],
396+
)
397+
Article.objects.create(
398+
headline="Cooking tips",
399+
number=2,
400+
body="This article is about pasta",
401+
plot_embedding=[0.9, 0.8, 0.7],
402+
)
403+
time.sleep(1)
404+
405+
def test_vector_search(self):
406+
vector_query = [0.1, 0.2, 0.3]
407+
expr = SearchVector(
408+
path="plot_embedding",
409+
query_vector=vector_query,
410+
num_candidates=5,
411+
limit=2,
412+
)
413+
qs = Article.objects.annotate(score=expr).order_by("-score")
414+
self.assertEqual(qs.first(), self.mars)

0 commit comments

Comments
 (0)