Skip to content

Commit 5cb05c5

Browse files
committed
Add vector search test.
1 parent e3b1c15 commit 5cb05c5

File tree

4 files changed

+117
-69
lines changed

4 files changed

+117
-69
lines changed

django_mongodb_backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -248,12 +248,13 @@ def _build_aggregation_pipeline(self, ids, group):
248248
pipeline.append({"$unset": "_id"})
249249
return pipeline
250250

251-
def _compound_searches_queries(self, searches):
251+
def _compound_searches_queries(self, searches, search_replacements):
252252
if not searches:
253253
return []
254254
if len(searches) > 1:
255255
raise ValueError("Cannot perform more than one search operation.")
256-
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}]
256+
score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore"
257+
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}]
257258

258259
def pre_sql_setup(self, with_col_aliases=False):
259260
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
@@ -262,7 +263,7 @@ def pre_sql_setup(self, with_col_aliases=False):
262263
)
263264
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
264265
all_replacements = {**search_replacements, **group_replacements}
265-
self.search_pipeline = self._compound_searches_queries(searches)
266+
self.search_pipeline = self._compound_searches_queries(searches, search_replacements)
266267
# query.group_by is either:
267268
# - None: no GROUP BY
268269
# - True: group by select fields

django_mongodb_backend/expressions/builtins.py

Lines changed: 62 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -593,61 +593,6 @@ def get_search_fields(self):
593593
return needed_fields
594594

595595

596-
class SearchVector(SearchExpression):
597-
def __init__(
598-
self,
599-
path,
600-
query_vector,
601-
limit,
602-
num_candidates=None,
603-
exact=None,
604-
filter=None,
605-
):
606-
self.path = path
607-
self.query_vector = query_vector
608-
self.limit = limit
609-
self.num_candidates = num_candidates
610-
self.exact = exact
611-
self.filter = filter
612-
super().__init__()
613-
614-
def __invert__(self):
615-
return ValueError("SearchVector cannot be negated")
616-
617-
def __and__(self, other):
618-
raise NotSupportedError("SearchVector cannot be combined")
619-
620-
def __rand__(self, other):
621-
raise NotSupportedError("SearchVector cannot be combined")
622-
623-
def __or__(self, other):
624-
raise NotSupportedError("SearchVector cannot be combined")
625-
626-
def __ror__(self, other):
627-
raise NotSupportedError("SearchVector cannot be combined")
628-
629-
def get_search_fields(self):
630-
return {self.path}
631-
632-
def _get_query_index(self, field, compiler):
633-
return "default"
634-
635-
def as_mql(self, compiler, connection):
636-
params = {
637-
"index": self._get_query_index(self.get_search_fields()),
638-
"path": self.path,
639-
"queryVector": self.query_vector,
640-
"limit": self.limit,
641-
}
642-
if self.num_candidates is not None:
643-
params["numCandidates"] = self.num_candidates
644-
if self.exact is not None:
645-
params["exact"] = self.exact
646-
if self.filter is not None:
647-
params["filter"] = self.filter
648-
return {"$vectorSearch": params}
649-
650-
651596
class CompoundExpression(SearchExpression):
652597
def __init__(
653598
self,
@@ -711,22 +656,76 @@ def resolve(node, negated=False):
711656
if operator == Operator.OR:
712657
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
713658
if operator == Operator.AND:
714-
# NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
715-
# return CompoundExpression(
716-
# must=lhs_compound.must + rhs_compound.must,
717-
# must_not=lhs_compound.must_not + rhs_compound.must_not,
718-
# should=lhs_compound.should + rhs_compound.should,
719-
# filter=lhs_compound.filter + rhs_compound.filter,
720-
# )
721659
return CompoundExpression(must=[lhs_compound, rhs_compound])
722-
# not operator
723660
return lhs_compound
724661

725662
def as_mql(self, compiler, connection):
726663
expression = self.resolve(self)
727664
return expression.as_mql(compiler, connection)
728665

729666

667+
class SearchVector(SearchExpression):
    """
    Search expression that compiles to a MongoDB ``$vectorSearch`` stage.

    Arguments:
        path: Name of the field to search; also the only field reported by
            ``get_search_fields()``.
        query_vector: Value sent as the stage's ``queryVector``.
        limit: Value sent as the stage's ``limit``.
        num_candidates: Optional value sent as ``numCandidates``.
        exact: Optional value sent as ``exact``.
        filter: Optional value sent as ``filter``.

    Optional arguments left as ``None`` are omitted from the generated stage.

    A vector search can be neither negated nor combined with other search
    expressions: ``~`` raises ``ValueError`` and ``&`` / ``|`` raise
    ``NotSupportedError``.
    """

    def __init__(
        self,
        path,
        query_vector,
        limit,
        num_candidates=None,
        exact=None,
        filter=None,
    ):
        self.path = path
        self.query_vector = query_vector
        self.limit = limit
        self.num_candidates = num_candidates
        self.exact = exact
        self.filter = filter
        super().__init__()

    def __invert__(self):
        # The exception must be raised, not returned: returning it would make
        # ``~SearchVector(...)`` evaluate to an exception instance and let the
        # invalid negation go unnoticed.
        raise ValueError("SearchVector cannot be negated")

    def __and__(self, other):
        raise NotSupportedError("SearchVector cannot be combined")

    def __rand__(self, other):
        raise NotSupportedError("SearchVector cannot be combined")

    def __or__(self, other):
        raise NotSupportedError("SearchVector cannot be combined")

    def __ror__(self, other):
        raise NotSupportedError("SearchVector cannot be combined")

    def get_search_fields(self):
        """Return the set of field paths this search reads (just ``path``)."""
        return {self.path}

    def _get_query_index(self, fields, compiler):
        """
        Return the name of the first ``vectorSearch`` index on the compiler's
        collection whose indexed paths include every entry of ``fields``.

        Falls back to ``"default"`` when no such index is found.
        """
        for search_index in compiler.collection.list_search_indexes():
            if search_index["type"] != "vectorSearch":
                continue
            indexed_paths = {
                field["path"] for field in search_index["latestDefinition"]["fields"]
            }
            if fields.issubset(indexed_paths):
                return search_index["name"]
        return "default"

    def as_mql(self, compiler, connection):
        """Return the ``{"$vectorSearch": {...}}`` stage for this expression."""
        params = {
            "index": self._get_query_index(self.get_search_fields(), compiler),
            "path": self.path,
            "queryVector": self.query_vector,
            "limit": self.limit,
        }
        # Optional settings are only sent when explicitly provided.
        if self.num_candidates is not None:
            params["numCandidates"] = self.num_candidates
        if self.exact is not None:
            params["exact"] = self.exact
        if self.filter is not None:
            params["filter"] = self.filter
        return {"$vectorSearch": params}
727+
728+
730729
class SearchScoreOption:
731730
"""Class to mutate scoring on a search operation"""
732731

tests/queries_/models.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from django.db import models
22

3-
from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField
3+
from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField
44

55

66
class Author(models.Model):
@@ -60,3 +60,4 @@ class Article(models.Model):
6060
number = models.IntegerField()
6161
body = models.TextField()
6262
location = models.JSONField(null=True)
63+
plot_embedding = ArrayField(models.FloatField(), size=3)

tests/queries_/test_search.py

Lines changed: 49 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
SearchRange,
1818
SearchRegex,
1919
SearchText,
20+
SearchVector,
2021
SearchWildcard,
2122
)
2223

@@ -29,9 +30,9 @@ def _get_collection(model):
2930
return connection.database.get_collection(model._meta.db_table)
3031

3132
@staticmethod
def create_search_index(model, index_name, definition, type="search"):
    """
    Create a search index named ``index_name`` on ``model``'s collection.

    ``definition`` and ``type`` are forwarded unchanged to ``SearchIndexModel``;
    ``type`` defaults to ``"search"``.
    """
    coll = CreateIndexMixin._get_collection(model)
    search_index = SearchIndexModel(name=index_name, definition=definition, type=type)
    coll.create_search_index(search_index)
3637

3738

@@ -365,3 +366,49 @@ def test_compound_operations(self):
365366
)
366367
qs = Article.objects.annotate(score=expr)
367368
self.assertCountEqual(qs, [self.mars_mission, self.exoplanet])
369+
370+
371+
class SearchVectorTest(TestCase, CreateIndexMixin):
    """Check ``SearchVector`` against a vector index on ``Article.plot_embedding``."""

    @classmethod
    def setUpTestData(cls):
        # Three-dimensional vector index over the ``plot_embedding`` field.
        vector_index_definition = {
            "fields": [
                {
                    "type": "vector",
                    "path": "plot_embedding",
                    "numDimensions": 3,
                    "similarity": "cosine",
                    "quantization": "scalar",
                }
            ]
        }
        cls.create_search_index(
            Article,
            "vector_index",
            vector_index_definition,
            type="vectorSearch",
        )

        # Only the article expected to match is kept on the class.
        cls.mars = Article.objects.create(
            headline="Mars landing",
            number=1,
            body="The rover has landed on Mars",
            plot_embedding=[0.1, 0.2, 0.3],
        )
        Article.objects.create(
            headline="Cooking tips",
            number=2,
            body="This article is about pasta",
            plot_embedding=[0.9, 0.8, 0.7],
        )
        # NOTE(review): fixed one-second pause, presumably to let the search
        # index pick up the new documents -- confirm this is long enough.
        time.sleep(1)

    def test_vector_search(self):
        # Query with the exact embedding stored on the Mars article.
        expr = SearchVector(
            path="plot_embedding",
            query_vector=[0.1, 0.2, 0.3],
            num_candidates=5,
            limit=2,
        )
        ranked = Article.objects.annotate(score=expr).order_by("-score")
        self.assertEqual(ranked.first(), self.mars)

0 commit comments

Comments
 (0)