Skip to content

Commit eec1dd7

Browse files
committed
Adapt query and compiler for operator support.
1 parent 4cea223 commit eec1dd7

File tree

6 files changed

+138
-26
lines changed

6 files changed

+138
-26
lines changed

django_mongodb_backend/compiler.py

Lines changed: 113 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20+
from .expressions.search import SearchExpression, SearchVector
2021
from .query import MongoQuery, wrap_database_errors
2122

2223

@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
3435
# A list of OrderBy objects for this query.
3536
self.order_by_objs = None
3637
self.subqueries = []
38+
# Atlas search calls
39+
self.search_pipeline = []
3740

3841
def _get_group_alias_column(self, expr, annotation_group_idx):
3942
"""Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
5760
column_target.set_attributes_from_name(alias)
5861
return Col(self.collection_name, column_target)
5962

63+
def _get_replace_expr(self, sub_expr, group, alias):
    """
    Record ``sub_expr`` in ``group`` under ``alias`` and return the expression
    that should replace ``sub_expr`` in the rest of the query.

    ``group`` is mutated in place: ``group[alias]`` is set to the MQL of
    ``sub_expr`` (wrapped in ``$addToSet`` when the expression is distinct).
    The returned expression is built around a ``Col`` that points at
    ``alias`` on this compiler's collection.
    """
    # Build a column named after the alias, typed like the expression output.
    column_target = sub_expr.output_field.clone()
    column_target.db_column = alias
    column_target.set_attributes_from_name(alias)
    inner_column = Col(self.collection_name, column_target)
    # getattr() with a default: expressions without a ``distinct`` attribute
    # are treated as non-distinct.
    if getattr(sub_expr, "distinct", False):
        # If the expression should return distinct values, use
        # $addToSet to deduplicate.
        rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        group[alias] = {"$addToSet": rhs}
        replacing_expr = sub_expr.copy()
        replacing_expr.set_source_expressions([inner_column, None])
    else:
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacing_expr = inner_column
    # Count must return 0 rather than null.
    if isinstance(sub_expr, Count):
        replacing_expr = Coalesce(replacing_expr, 0)
    # Variance = StdDev^2
    if isinstance(sub_expr, Variance):
        replacing_expr = Power(replacing_expr, 2)
    return replacing_expr
85+
6086
def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
6187
"""
6288
Prepare expressions for the aggregation pipeline.
@@ -80,29 +106,33 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
80106
alias = (
81107
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
82108
)
83-
column_target = sub_expr.output_field.clone()
84-
column_target.db_column = alias
85-
column_target.set_attributes_from_name(alias)
86-
inner_column = Col(self.collection_name, column_target)
87-
if sub_expr.distinct:
88-
# If the expression should return distinct values, use
89-
# $addToSet to deduplicate.
90-
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
91-
group[alias] = {"$addToSet": rhs}
92-
replacing_expr = sub_expr.copy()
93-
replacing_expr.set_source_expressions([inner_column, None])
94-
else:
95-
group[alias] = sub_expr.as_mql(self, self.connection)
96-
replacing_expr = inner_column
97-
# Count must return 0 rather than null.
98-
if isinstance(sub_expr, Count):
99-
replacing_expr = Coalesce(replacing_expr, 0)
100-
# Variance = StdDev^2
101-
if isinstance(sub_expr, Variance):
102-
replacing_expr = Power(replacing_expr, 2)
103-
replacements[sub_expr] = replacing_expr
109+
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
104110
return replacements, group
105111

112+
def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
113+
searches = {}
114+
for sub_expr in self._get_search_expressions(expression):
115+
if sub_expr not in replacements:
116+
alias = f"__search_expr.search{next(search_idx)}"
117+
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
118+
119+
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
120+
replacements = {}
121+
annotation_group_idx = itertools.count(start=1)
122+
for expr in self.query.annotation_select.values():
123+
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
124+
125+
for expr, _ in order_by:
126+
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
127+
128+
self._prepare_search_expressions_for_pipeline(
129+
self.having, annotation_group_idx, replacements
130+
)
131+
self._prepare_search_expressions_for_pipeline(
132+
self.get_where(), annotation_group_idx, replacements
133+
)
134+
return replacements
135+
106136
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
107137
"""Prepare annotations for the aggregation pipeline."""
108138
replacements = {}
@@ -207,9 +237,57 @@ def _build_aggregation_pipeline(self, ids, group):
207237
pipeline.append({"$unset": "_id"})
208238
return pipeline
209239

240+
def _compound_searches_queries(self, search_replacements):
241+
if not search_replacements:
242+
return []
243+
if len(search_replacements) > 1:
244+
has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
245+
has_vector_search = any(
246+
isinstance(search, SearchVector) for search in search_replacements
247+
)
248+
if has_search and has_vector_search:
249+
raise ValueError(
250+
"Cannot combine a `$vectorSearch` with a `$search` operator. "
251+
"If you need to combine them, consider restructuring your query logic or "
252+
"running them as separate queries."
253+
)
254+
if not has_search:
255+
raise ValueError(
256+
"Cannot combine two `$vectorSearch` operator. "
257+
"If you need to combine them, consider restructuring your query logic or "
258+
"running them as separate queries."
259+
)
260+
raise ValueError(
261+
"Only one $search operation is allowed per query. "
262+
f"Received {len(search_replacements)} search expressions. "
263+
"To combine multiple search expressions, use either a CompoundExpression for "
264+
"fine-grained control or CombinedSearchExpression for simple logical combinations."
265+
)
266+
pipeline = []
267+
for search, result_col in search_replacements.items():
268+
score_function = (
269+
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
270+
)
271+
pipeline.extend(
272+
[
273+
search.as_mql(self, self.connection),
274+
{
275+
"$addFields": {
276+
result_col.as_mql(self, self.connection, as_path=True): {
277+
"$meta": score_function
278+
}
279+
}
280+
},
281+
]
282+
)
283+
return pipeline
284+
210285
def pre_sql_setup(self, with_col_aliases=False):
211286
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
212-
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
287+
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
288+
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
289+
all_replacements = {**search_replacements, **group_replacements}
290+
self.search_pipeline = self._compound_searches_queries(search_replacements)
213291
# query.group_by is either:
214292
# - None: no GROUP BY
215293
# - True: group by select fields
@@ -234,6 +312,9 @@ def pre_sql_setup(self, with_col_aliases=False):
234312
for target, expr in self.query.annotation_select.items()
235313
}
236314
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
315+
if (where := self.get_where()) and search_replacements:
316+
where = where.replace_expressions(search_replacements)
317+
self.set_where(where)
237318
return extra_select, order_by, group_by
238319

239320
def execute_sql(
@@ -557,10 +638,16 @@ def get_lookup_pipeline(self):
557638
return result
558639

559640
def _get_aggregate_expressions(self, expr):
    """Return an iterator over the ``Aggregate`` expressions found in ``expr``."""
    return self._get_all_expressions_of_type(expr, Aggregate)
642+
643+
def _get_search_expressions(self, expr):
    """Return an iterator over the ``SearchExpression`` nodes found in ``expr``."""
    return self._get_all_expressions_of_type(expr, SearchExpression)
645+
646+
def _get_all_expressions_of_type(self, expr, target_type):
560647
stack = [expr]
561648
while stack:
562649
expr = stack.pop()
563-
if isinstance(expr, Aggregate):
650+
if isinstance(expr, target_type):
564651
yield expr
565652
elif hasattr(expr, "get_source_expressions"):
566653
stack.extend(expr.get_source_expressions())
@@ -629,6 +716,9 @@ def _get_ordering(self):
629716
def get_where(self):
    """
    Return the where node for this compiler.

    Prefer the ``where`` attribute set on the compiler itself; fall back to
    ``self.query.where`` when the compiler has none.
    """
    return getattr(self, "where", self.query.where)
631718

719+
def set_where(self, value):
    """Store ``value`` as this compiler's where node (read back by ``get_where()``)."""
    self.where = value
721+
632722
def explain_query(self):
633723
# Validate format (none supported) and options.
634724
options = self.connection.ops.explain_query_prefix(

django_mongodb_backend/creation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def _destroy_test_db(self, test_database_name, verbosity):
2222

2323
for collection in self.connection.introspection.table_names():
2424
if not collection.startswith("system."):
25+
if self.connection.features.supports_atlas_search:
26+
db_collection = self.connection.database.get_collection(collection)
27+
for search_indexes in db_collection.list_search_indexes():
28+
db_collection.drop_search_index(search_indexes["name"])
2529
self.connection.database.drop_collection(collection)
2630

2731
def create_test_db(self, *args, **kwargs):

django_mongodb_backend/expressions/builtins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def case(self, compiler, connection):
5353
}
5454

5555

56-
def col(self, compiler, connection): # noqa: ARG001
56+
def col(self, compiler, connection, as_path=False): # noqa: ARG001
5757
# If the column is part of a subquery and belongs to one of the parent
5858
# queries, it will be stored for reference using $let in a $lookup stage.
5959
# If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +71,7 @@ def col(self, compiler, connection): # noqa: ARG001
7171
# Add the column's collection's alias for columns in joined collections.
7272
has_alias = self.alias and self.alias != compiler.collection_name
7373
prefix = f"{self.alias}." if has_alias else ""
74-
return f"${prefix}{self.target.column}"
74+
return f"{prefix}{self.target.column}" if as_path else f"${prefix}{self.target.column}"
7575

7676

7777
def col_pairs(self, compiler, connection):
django_mongodb_backend/expressions/search.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from django.db.models import Expression
2+
3+
4+
class SearchExpression(Expression):
    """Base expression node for MongoDB Atlas `$search` stages."""
6+
7+
8+
class SearchVector(SearchExpression):
    """
    Atlas Search expression that performs vector similarity search on embedded vectors.

    The compiler distinguishes this class from other search expressions: its
    score is read from the `vectorSearchScore` metadata rather than
    `searchScore`, and it cannot be combined with another search expression
    in the same query.
    """

django_mongodb_backend/fields/embedded_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,16 @@ def get_transform(self, name):
184184
f"{suggestion}"
185185
)
186186

187-
def as_mql(self, compiler, connection):
187+
def as_mql(self, compiler, connection, as_path=False):
188188
previous = self
189189
key_transforms = []
190190
while isinstance(previous, KeyTransform):
191191
key_transforms.insert(0, previous.key_name)
192192
previous = previous.lhs
193+
if as_path:
194+
mql = previous.as_mql(compiler, connection, as_path=True)
195+
mql_path = ".".join(key_transforms)
196+
return f"{mql}.{mql_path}"
193197
mql = previous.as_mql(compiler, connection)
194198
for key in key_transforms:
195199
mql = {"$getField": {"input": mql, "field": key}}

django_mongodb_backend/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, compiler):
4949
self.lookup_pipeline = None
5050
self.project_fields = None
5151
self.aggregation_pipeline = compiler.aggregation_pipeline
52+
self.search_pipeline = compiler.search_pipeline
5253
self.extra_fields = None
5354
self.combinator_pipeline = None
5455
# $lookup stage that encapsulates the pipeline for performing a nested
@@ -81,6 +82,8 @@ def get_cursor(self):
8182

8283
def get_pipeline(self):
8384
pipeline = []
85+
if self.search_pipeline:
86+
pipeline.extend(self.search_pipeline)
8487
if self.lookup_pipeline:
8588
pipeline.extend(self.lookup_pipeline)
8689
for query in self.subqueries or ():

0 commit comments

Comments
 (0)