Skip to content

Commit 310a8cc

Browse files
committed
Adapt query and compiler for operator support.
1 parent 60a6c26 commit 310a8cc

File tree

6 files changed

+138
-26
lines changed

6 files changed

+138
-26
lines changed

django_mongodb_backend/compiler.py

Lines changed: 113 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from django.utils.functional import cached_property
1717
from pymongo import ASCENDING, DESCENDING
1818

19+
from .expressions.search import SearchExpression, SearchVector
1920
from .query import MongoQuery, wrap_database_errors
2021

2122

@@ -33,6 +34,8 @@ def __init__(self, *args, **kwargs):
3334
# A list of OrderBy objects for this query.
3435
self.order_by_objs = None
3536
self.subqueries = []
37+
# Atlas search calls
38+
self.search_pipeline = []
3639

3740
def _get_group_alias_column(self, expr, annotation_group_idx):
3841
"""Generate a dummy field for use in the ids fields in $group."""
@@ -56,6 +59,29 @@ def _get_column_from_expression(self, expr, alias):
5659
column_target.set_attributes_from_name(alias)
5760
return Col(self.collection_name, column_target)
5861

62+
def _get_replace_expr(self, sub_expr, group, alias):
    """
    Record `sub_expr` in `group` under `alias` and return its replacement.

    The MQL for `sub_expr` is stored in `group[alias]`. The returned
    expression is a column named `alias` on this compiler's collection
    (cloned from the expression's output field), wrapped as needed so the
    rest of the query can refer to the computed value.
    """
    target_field = sub_expr.output_field.clone()
    target_field.db_column = alias
    target_field.set_attributes_from_name(alias)
    alias_column = Col(self.collection_name, target_field)
    if not getattr(sub_expr, "distinct", False):
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacement = alias_column
    else:
        # Distinct expressions are deduplicated with $addToSet; a copy of
        # the expression is then pointed at the collected set.
        inner_mql = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        group[alias] = {"$addToSet": inner_mql}
        replacement = sub_expr.copy()
        replacement.set_source_expressions([alias_column, None])
    if isinstance(sub_expr, Count):
        # Count must produce 0, never null.
        replacement = Coalesce(replacement, 0)
    if isinstance(sub_expr, Variance):
        # Variance = StdDev^2
        replacement = Power(replacement, 2)
    return replacement
84+
5985
def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
6086
"""
6187
Prepare expressions for the aggregation pipeline.
@@ -79,29 +105,33 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
79105
alias = (
80106
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
81107
)
82-
column_target = sub_expr.output_field.clone()
83-
column_target.db_column = alias
84-
column_target.set_attributes_from_name(alias)
85-
inner_column = Col(self.collection_name, column_target)
86-
if sub_expr.distinct:
87-
# If the expression should return distinct values, use
88-
# $addToSet to deduplicate.
89-
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
90-
group[alias] = {"$addToSet": rhs}
91-
replacing_expr = sub_expr.copy()
92-
replacing_expr.set_source_expressions([inner_column, None])
93-
else:
94-
group[alias] = sub_expr.as_mql(self, self.connection)
95-
replacing_expr = inner_column
96-
# Count must return 0 rather than null.
97-
if isinstance(sub_expr, Count):
98-
replacing_expr = Coalesce(replacing_expr, 0)
99-
# Variance = StdDev^2
100-
if isinstance(sub_expr, Variance):
101-
replacing_expr = Power(replacing_expr, 2)
102-
replacements[sub_expr] = replacing_expr
108+
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
103109
return replacements, group
104110

111+
def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
112+
searches = {}
113+
for sub_expr in self._get_search_expressions(expression):
114+
if sub_expr not in replacements:
115+
alias = f"__search_expr.search{next(search_idx)}"
116+
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
117+
118+
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
119+
replacements = {}
120+
annotation_group_idx = itertools.count(start=1)
121+
for expr in self.query.annotation_select.values():
122+
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
123+
124+
for expr, _ in order_by:
125+
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
126+
127+
self._prepare_search_expressions_for_pipeline(
128+
self.having, annotation_group_idx, replacements
129+
)
130+
self._prepare_search_expressions_for_pipeline(
131+
self.get_where(), annotation_group_idx, replacements
132+
)
133+
return replacements
134+
105135
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
106136
"""Prepare annotations for the aggregation pipeline."""
107137
replacements = {}
@@ -206,9 +236,57 @@ def _build_aggregation_pipeline(self, ids, group):
206236
pipeline.append({"$unset": "_id"})
207237
return pipeline
208238

239+
def _compound_searches_queries(self, search_replacements):
240+
if not search_replacements:
241+
return []
242+
if len(search_replacements) > 1:
243+
has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
244+
has_vector_search = any(
245+
isinstance(search, SearchVector) for search in search_replacements
246+
)
247+
if has_search and has_vector_search:
248+
raise ValueError(
249+
"Cannot combine a `$vectorSearch` with a `$search` operator. "
250+
"If you need to combine them, consider restructuring your query logic or "
251+
"running them as separate queries."
252+
)
253+
if not has_search:
254+
raise ValueError(
255+
"Cannot combine two `$vectorSearch` operator. "
256+
"If you need to combine them, consider restructuring your query logic or "
257+
"running them as separate queries."
258+
)
259+
raise ValueError(
260+
"Only one $search operation is allowed per query. "
261+
f"Received {len(search_replacements)} search expressions. "
262+
"To combine multiple search expressions, use either a CompoundExpression for "
263+
"fine-grained control or CombinedSearchExpression for simple logical combinations."
264+
)
265+
pipeline = []
266+
for search, result_col in search_replacements.items():
267+
score_function = (
268+
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
269+
)
270+
pipeline.extend(
271+
[
272+
search.as_mql(self, self.connection),
273+
{
274+
"$addFields": {
275+
result_col.as_mql(self, self.connection, as_path=True): {
276+
"$meta": score_function
277+
}
278+
}
279+
},
280+
]
281+
)
282+
return pipeline
283+
209284
def pre_sql_setup(self, with_col_aliases=False):
210285
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
211-
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
286+
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
287+
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
288+
all_replacements = {**search_replacements, **group_replacements}
289+
self.search_pipeline = self._compound_searches_queries(search_replacements)
212290
# query.group_by is either:
213291
# - None: no GROUP BY
214292
# - True: group by select fields
@@ -233,6 +311,9 @@ def pre_sql_setup(self, with_col_aliases=False):
233311
for target, expr in self.query.annotation_select.items()
234312
}
235313
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
314+
if (where := self.get_where()) and search_replacements:
315+
where = where.replace_expressions(search_replacements)
316+
self.set_where(where)
236317
return extra_select, order_by, group_by
237318

238319
def execute_sql(
@@ -556,10 +637,16 @@ def get_lookup_pipeline(self):
556637
return result
557638

558639
def _get_aggregate_expressions(self, expr):
    """Return an iterator over the Aggregate expressions found in `expr`."""
    return self._get_all_expressions_of_type(expr, target_type=Aggregate)
641+
642+
def _get_search_expressions(self, expr):
    """Return an iterator over the SearchExpression nodes found in `expr`."""
    return self._get_all_expressions_of_type(expr, target_type=SearchExpression)
644+
645+
def _get_all_expressions_of_type(self, expr, target_type):
559646
stack = [expr]
560647
while stack:
561648
expr = stack.pop()
562-
if isinstance(expr, Aggregate):
649+
if isinstance(expr, target_type):
563650
yield expr
564651
elif hasattr(expr, "get_source_expressions"):
565652
stack.extend(expr.get_source_expressions())
@@ -628,6 +715,9 @@ def _get_ordering(self):
628715
def get_where(self):
    """
    Return the `where` attribute set on this object, falling back to
    `self.query.where` when no such attribute exists.
    """
    query_where = self.query.where
    return getattr(self, "where", query_where)
630717

718+
def set_where(self, value):
    """Store `value` as this object's `where` attribute (read by get_where())."""
    setattr(self, "where", value)
720+
631721
def explain_query(self):
632722
# Validate format (none supported) and options.
633723
options = self.connection.ops.explain_query_prefix(

django_mongodb_backend/creation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def _destroy_test_db(self, test_database_name, verbosity):
2222

2323
for collection in self.connection.introspection.table_names():
2424
if not collection.startswith("system."):
25+
if self.connection.features.supports_atlas_search:
26+
db_collection = self.connection.database.get_collection(collection)
27+
for search_indexes in db_collection.list_search_indexes():
28+
db_collection.drop_search_index(search_indexes["name"])
2529
self.connection.database.drop_collection(collection)
2630

2731
def create_test_db(self, *args, **kwargs):

django_mongodb_backend/expressions/builtins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def case(self, compiler, connection):
5353
}
5454

5555

56-
def col(self, compiler, connection): # noqa: ARG001
56+
def col(self, compiler, connection, as_path=False): # noqa: ARG001
5757
# If the column is part of a subquery and belongs to one of the parent
5858
# queries, it will be stored for reference using $let in a $lookup stage.
5959
# If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +71,7 @@ def col(self, compiler, connection): # noqa: ARG001
7171
# Add the column's collection's alias for columns in joined collections.
7272
has_alias = self.alias and self.alias != compiler.collection_name
7373
prefix = f"{self.alias}." if has_alias else ""
74-
return f"${prefix}{self.target.column}"
74+
return f"{prefix}{self.target.column}" if as_path else f"${prefix}{self.target.column}"
7575

7676

7777
def col_pairs(self, compiler, connection):
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from django.db.models import Expression
2+
3+
4+
class SearchExpression(Expression):
    """Base class for expressions that represent a MongoDB Atlas `$search` stage."""
6+
7+
8+
class SearchVector(SearchExpression):
    """
    Atlas Search expression for vector similarity search over embedded
    vectors.
    """

django_mongodb_backend/fields/embedded_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,16 @@ def get_transform(self, name):
184184
f"{suggestion}"
185185
)
186186

187-
def as_mql(self, compiler, connection):
187+
def as_mql(self, compiler, connection, as_path=False):
188188
previous = self
189189
key_transforms = []
190190
while isinstance(previous, KeyTransform):
191191
key_transforms.insert(0, previous.key_name)
192192
previous = previous.lhs
193+
if as_path:
194+
mql = previous.as_mql(compiler, connection, as_path=True)
195+
mql_path = ".".join(key_transforms)
196+
return f"{mql}.{mql_path}"
193197
mql = previous.as_mql(compiler, connection)
194198
for key in key_transforms:
195199
mql = {"$getField": {"input": mql, "field": key}}

django_mongodb_backend/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, compiler):
4949
self.lookup_pipeline = None
5050
self.project_fields = None
5151
self.aggregation_pipeline = compiler.aggregation_pipeline
52+
self.search_pipeline = compiler.search_pipeline
5253
self.extra_fields = None
5354
self.combinator_pipeline = None
5455
# $lookup stage that encapsulates the pipeline for performing a nested
@@ -81,6 +82,8 @@ def get_cursor(self):
8182

8283
def get_pipeline(self):
8384
pipeline = []
85+
if self.search_pipeline:
86+
pipeline.extend(self.search_pipeline)
8487
if self.lookup_pipeline:
8588
pipeline.extend(self.lookup_pipeline)
8689
for query in self.subqueries or ():

0 commit comments

Comments
 (0)