Skip to content

Commit 6dac022

Browse files
committed
Add combinable test
1 parent 5cb05c5 commit 6dac022

File tree

5 files changed

+292
-102
lines changed

5 files changed

+292
-102
lines changed

django_mongodb_backend/compiler.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20-
from .expressions.builtins import SearchExpression
20+
from .expressions.builtins import SearchExpression, SearchVector
2121
from .query import MongoQuery, wrap_database_errors
2222

2323

@@ -109,37 +109,26 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
109109
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
110110
return replacements, group
111111

112-
def _prepare_search_expressions_for_pipeline(
113-
self, expression, target, search_idx, replacements
114-
):
112+
def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
115113
searches = {}
116114
for sub_expr in self._get_search_expressions(expression):
117115
if sub_expr not in replacements:
118116
alias = f"__search_expr.search{next(search_idx)}"
119117
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
120-
return list(searches.values())
121118

122119
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
123120
replacements = {}
124-
searches = []
125121
annotation_group_idx = itertools.count(start=1)
126-
for target, expr in self.query.annotation_select.items():
127-
expr_searches = self._prepare_search_expressions_for_pipeline(
128-
expr, target, annotation_group_idx, replacements
129-
)
130-
searches += expr_searches
122+
for expr in self.query.annotation_select.values():
123+
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
131124

132125
for expr, _ in order_by:
133-
expr_searches = self._prepare_search_expressions_for_pipeline(
134-
expr, None, annotation_group_idx, replacements
135-
)
136-
searches += expr_searches
126+
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
137127

138-
having_group = self._prepare_search_expressions_for_pipeline(
139-
self.having, None, annotation_group_idx, replacements
128+
self._prepare_search_expressions_for_pipeline(
129+
self.having, annotation_group_idx, replacements
140130
)
141-
searches += having_group
142-
return searches, replacements
131+
return replacements
143132

144133
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
145134
"""Prepare annotations for the aggregation pipeline."""
@@ -248,22 +237,36 @@ def _build_aggregation_pipeline(self, ids, group):
248237
pipeline.append({"$unset": "_id"})
249238
return pipeline
250239

251-
def _compound_searches_queries(self, searches, search_replacements):
252-
if not searches:
240+
def _compound_searches_queries(self, search_replacements):
241+
if not search_replacements:
253242
return []
254-
if len(searches) > 1:
243+
if len(search_replacements) > 1:
255244
raise ValueError("Cannot perform more than one search operation.")
256-
score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore"
257-
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}]
245+
pipeline = []
246+
for search, result_col in search_replacements.items():
247+
score_function = (
248+
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
249+
)
250+
pipeline.extend(
251+
[
252+
search.as_mql(self, self.connection),
253+
{
254+
"$addFields": {
255+
result_col.as_mql(self, self.connection).removeprefix("$"): {
256+
"$meta": score_function
257+
}
258+
}
259+
},
260+
]
261+
)
262+
return pipeline
258263

259264
def pre_sql_setup(self, with_col_aliases=False):
260265
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
261-
searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
262-
order_by
263-
)
266+
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
264267
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
265268
all_replacements = {**search_replacements, **group_replacements}
266-
self.search_pipeline = self._compound_searches_queries(searches, search_replacements)
269+
self.search_pipeline = self._compound_searches_queries(search_replacements)
267270
# query.group_by is either:
268271
# - None: no GROUP BY
269272
# - True: group by select fields

django_mongodb_backend/creation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def _destroy_test_db(self, test_database_name, verbosity):
2222

2323
for collection in self.connection.introspection.table_names():
2424
if not collection.startswith("system."):
25+
db_collection = self.connection.database.get_collection(collection)
26+
for search_indexes in db_collection.list_search_indexes():
27+
db_collection.drop_search_index(search_indexes["name"])
2528
self.connection.database.drop_collection(collection)
2629

2730
def create_test_db(self, *args, **kwargs):
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from django.test import SimpleTestCase
2+
3+
from django_mongodb_backend.expressions.builtins import (
4+
CombinedSearchExpression,
5+
CompoundExpression,
6+
SearchEquals,
7+
)
8+
9+
10+
class CombinedSearchExpressionResolutionTest(SimpleTestCase):
11+
def test_combined_expression_and_or_not_resolution(self):
12+
A = SearchEquals(path="headline", value="A")
13+
B = SearchEquals(path="headline", value="B")
14+
C = SearchEquals(path="headline", value="C")
15+
D = SearchEquals(path="headline", value="D")
16+
expr = (~A | B) & (C | D)
17+
solved = CombinedSearchExpression.resolve(expr)
18+
self.assertIsInstance(solved, CompoundExpression)
19+
solved_A = CompoundExpression(must_not=[CompoundExpression(must=[A])])
20+
solved_B = CompoundExpression(must=[B])
21+
solved_C = CompoundExpression(must=[C])
22+
solved_D = CompoundExpression(must=[D])
23+
self.assertCountEqual(solved.must[0].should, [solved_A, solved_B])
24+
self.assertEqual(solved.must[0].minimum_should_match, 1)
25+
self.assertEqual(solved.must[1].should, [solved_C, solved_D])
26+
27+
def test_combined_expression_de_morgans_resolution(self):
28+
A = SearchEquals(path="headline", value="A")
29+
B = SearchEquals(path="headline", value="B")
30+
C = SearchEquals(path="headline", value="C")
31+
D = SearchEquals(path="headline", value="D")
32+
expr = ~(A | B) & (C | D)
33+
solved_A = CompoundExpression(must_not=[CompoundExpression(must=[A])])
34+
solved_B = CompoundExpression(must_not=[CompoundExpression(must=[B])])
35+
solved_C = CompoundExpression(must=[C])
36+
solved_D = CompoundExpression(must=[D])
37+
solved = CombinedSearchExpression.resolve(expr)
38+
self.assertIsInstance(solved, CompoundExpression)
39+
self.assertCountEqual(solved.must[0].must, [solved_A, solved_B])
40+
self.assertEqual(solved.must[0].minimum_should_match, None)
41+
self.assertEqual(solved.must[1].should, [solved_C, solved_D])
42+
self.assertEqual(solved.minimum_should_match, None)
43+
44+
def test_combined_expression_doble_negation(self):
45+
A = SearchEquals(path="headline", value="A")
46+
expr = ~~A
47+
solved = CombinedSearchExpression.resolve(expr)
48+
solved_A = CompoundExpression(must=[A])
49+
self.assertIsInstance(solved, CompoundExpression)
50+
self.assertEqual(solved, solved_A)
51+
52+
def test_combined_expression_long_right_tree(self):
53+
A = SearchEquals(path="headline", value="A")
54+
B = SearchEquals(path="headline", value="B")
55+
C = SearchEquals(path="headline", value="C")
56+
D = SearchEquals(path="headline", value="D")
57+
solved_A = CompoundExpression(must=[A])
58+
solved_B = CompoundExpression(must_not=[CompoundExpression(must=[B])])
59+
solved_C = CompoundExpression(must=[C])
60+
solved_D = CompoundExpression(must=[D])
61+
expr = A & ~(B & ~(C & D))
62+
solved = CombinedSearchExpression.resolve(expr)
63+
self.assertIsInstance(solved, CompoundExpression)
64+
self.assertEqual(len(solved.must), 2)
65+
self.assertEqual(solved.must[0], solved_A)
66+
self.assertEqual(len(solved.must[1].should), 2)
67+
self.assertEqual(solved.must[1].should[0], solved_B)
68+
self.assertCountEqual(solved.must[1].should[1].must, [solved_C, solved_D])
69+
expr = A | ~(B | ~(C | D))
70+
solved = CombinedSearchExpression.resolve(expr)
71+
self.assertIsInstance(solved, CompoundExpression)
72+
self.assertEqual(len(solved.should), 2)
73+
self.assertEqual(solved.should[0], solved_A)
74+
self.assertEqual(len(solved.should[1].must), 2)
75+
self.assertEqual(solved.should[1].must[0], solved_B)
76+
self.assertCountEqual(solved.should[1].must[1].should, [solved_C, solved_D])

tests/queries_/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,4 @@ class Article(models.Model):
6060
number = models.IntegerField()
6161
body = models.TextField()
6262
location = models.JSONField(null=True)
63-
plot_embedding = ArrayField(models.FloatField(), size=3)
63+
plot_embedding = ArrayField(models.FloatField(), size=3, null=True)

0 commit comments

Comments
 (0)