17
17
from django .utils .functional import cached_property
18
18
from pymongo import ASCENDING , DESCENDING
19
19
20
+ from .expressions .search import SearchExpression , SearchVector
20
21
from .query import MongoQuery , wrap_database_errors
21
22
from .query_utils import is_direct_value
22
23
@@ -35,6 +36,8 @@ def __init__(self, *args, **kwargs):
35
36
# A list of OrderBy objects for this query.
36
37
self .order_by_objs = None
37
38
self .subqueries = []
39
+ # Atlas search stage.
40
+ self .search_pipeline = []
38
41
39
42
def _get_group_alias_column (self , expr , annotation_group_idx ):
40
43
"""Generate a dummy field for use in the ids fields in $group."""
@@ -58,6 +61,29 @@ def _get_column_from_expression(self, expr, alias):
58
61
column_target .set_attributes_from_name (alias )
59
62
return Col (self .collection_name , column_target )
60
63
64
def _get_replace_expr(self, sub_expr, group, alias):
    """
    Register *sub_expr* under *alias* in the given stage dict (*group*) and
    return the expression that should stand in for it in later pipeline
    stages.

    The replacement is a column reference to the computed alias, wrapped as
    needed so that aggregate semantics survive the substitution.
    """
    target_field = sub_expr.output_field.clone()
    target_field.db_column = alias
    target_field.set_attributes_from_name(alias)
    result_col = Col(self.collection_name, target_field)
    if getattr(sub_expr, "distinct", False):
        # Distinct aggregations deduplicate via $addToSet and re-apply the
        # aggregate over the collected set.
        group[alias] = {
            "$addToSet": sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        }
        replacement = sub_expr.copy()
        replacement.set_source_expressions([result_col, None])
    else:
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacement = result_col
    # Count must return 0 rather than null.
    if isinstance(sub_expr, Count):
        replacement = Coalesce(replacement, 0)
    # Variance = StdDev^2
    if isinstance(sub_expr, Variance):
        replacement = Power(replacement, 2)
    return replacement
86
+
61
87
def _prepare_expressions_for_pipeline (self , expression , target , annotation_group_idx ):
62
88
"""
63
89
Prepare expressions for the aggregation pipeline.
@@ -81,29 +107,51 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
81
107
alias = (
82
108
f"__aggregation{ next (annotation_group_idx )} " if sub_expr != expression else target
83
109
)
84
- column_target = sub_expr .output_field .clone ()
85
- column_target .db_column = alias
86
- column_target .set_attributes_from_name (alias )
87
- inner_column = Col (self .collection_name , column_target )
88
- if sub_expr .distinct :
89
- # If the expression should return distinct values, use
90
- # $addToSet to deduplicate.
91
- rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
92
- group [alias ] = {"$addToSet" : rhs }
93
- replacing_expr = sub_expr .copy ()
94
- replacing_expr .set_source_expressions ([inner_column , None ])
95
- else :
96
- group [alias ] = sub_expr .as_mql (self , self .connection )
97
- replacing_expr = inner_column
98
- # Count must return 0 rather than null.
99
- if isinstance (sub_expr , Count ):
100
- replacing_expr = Coalesce (replacing_expr , 0 )
101
- # Variance = StdDev^2
102
- if isinstance (sub_expr , Variance ):
103
- replacing_expr = Power (replacing_expr , 2 )
104
- replacements [sub_expr ] = replacing_expr
110
+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , group , alias )
105
111
return replacements , group
106
112
113
def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
    """
    Collect and prepare unique search expressions for inclusion in an
    aggregation pipeline.

    Walk every search sub-expression of *expression*; each one not already
    seen gets a unique ``__search_expr.search#`` alias and a replacement
    expression recorded in *replacements*.
    """
    # The stage dict is only needed by _get_replace_expr's interface; search
    # stages are emitted separately, so it is discarded here.
    stage_columns = {}
    for search in self._get_search_expressions(expression):
        if search in replacements:
            continue
        alias = f"__search_expr.search{next(search_idx)}"
        replacements[search] = self._get_replace_expr(search, stage_columns, alias)
127
+
128
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
    """
    Prepare expressions for the search pipeline.

    Handle the computation of search functions used by various expressions.
    Separate and create intermediate columns, and replace nodes to simulate
    a search operation.

    To apply operations over the $search or $searchVector stages, compute
    the $search or $vectorSearch first, then apply additional operations in
    a subsequent stage by replacing the aggregate expressions with a new
    document field prefixed by `__search_expr.search#`.
    """
    replacements = {}
    search_idx = itertools.count(start=1)
    # Search expressions may appear in the select annotations, the ordering,
    # the HAVING clause, or the WHERE clause.
    sources = itertools.chain(
        self.query.annotation_select.values(),
        (expr for expr, _ in order_by),
        (self.having, self.get_where()),
    )
    for expr in sources:
        self._prepare_search_expressions_for_pipeline(expr, search_idx, replacements)
    return replacements
154
+
107
155
def _prepare_annotations_for_aggregation_pipeline (self , order_by ):
108
156
"""Prepare annotations for the aggregation pipeline."""
109
157
replacements = {}
@@ -208,9 +256,67 @@ def _build_aggregation_pipeline(self, ids, group):
208
256
pipeline .append ({"$unset" : "_id" })
209
257
return pipeline
210
258
259
+ def _compound_searches_queries (self , search_replacements ):
260
+ """
261
+ Build a query pipeline from a mapping of search expressions to result
262
+ columns.
263
+
264
+ Currently only a single $search or $vectorSearch expression is
265
+ supported. Combining multiple search expressions raises ValueError.
266
+
267
+ This method will eventually support hybrid search by allowing the
268
+ combination of $search and $vectorSearch operations.
269
+ """
270
+ if not search_replacements :
271
+ return []
272
+ if len (search_replacements ) > 1 :
273
+ has_search = any (not isinstance (search , SearchVector ) for search in search_replacements )
274
+ has_vector_search = any (
275
+ isinstance (search , SearchVector ) for search in search_replacements
276
+ )
277
+ if has_search and has_vector_search :
278
+ raise ValueError (
279
+ "Cannot combine a `$vectorSearch` with a `$search` operator. "
280
+ "If you need to combine them, consider restructuring your query logic or "
281
+ "running them as separate queries."
282
+ )
283
+ if has_vector_search :
284
+ raise ValueError (
285
+ "Cannot combine two `$vectorSearch` operator. "
286
+ "If you need to combine them, consider restructuring your query logic or "
287
+ "running them as separate queries."
288
+ )
289
+ raise ValueError (
290
+ "Only one $search operation is allowed per query. "
291
+ f"Received { len (search_replacements )} search expressions. "
292
+ "To combine multiple search expressions, use either a CompoundExpression for "
293
+ "fine-grained control or CombinedSearchExpression for simple logical combinations."
294
+ )
295
+ pipeline = []
296
+ for search , result_col in search_replacements .items ():
297
+ score_function = (
298
+ "vectorSearchScore" if isinstance (search , SearchVector ) else "searchScore"
299
+ )
300
+ pipeline .extend (
301
+ [
302
+ search .as_mql (self , self .connection ),
303
+ {
304
+ "$addFields" : {
305
+ result_col .as_mql (self , self .connection , as_path = True ): {
306
+ "$meta" : score_function
307
+ }
308
+ }
309
+ },
310
+ ]
311
+ )
312
+ return pipeline
313
+
211
314
def pre_sql_setup (self , with_col_aliases = False ):
212
315
extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
213
- group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
316
+ search_replacements = self ._prepare_search_query_for_aggregation_pipeline (order_by )
317
+ group , group_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
318
+ all_replacements = {** search_replacements , ** group_replacements }
319
+ self .search_pipeline = self ._compound_searches_queries (search_replacements )
214
320
# query.group_by is either:
215
321
# - None: no GROUP BY
216
322
# - True: group by select fields
@@ -235,6 +341,8 @@ def pre_sql_setup(self, with_col_aliases=False):
235
341
for target , expr in self .query .annotation_select .items ()
236
342
}
237
343
self .order_by_objs = [expr .replace_expressions (all_replacements ) for expr , _ in order_by ]
344
+ if (where := self .get_where ()) and search_replacements :
345
+ self .set_where (where .replace_expressions (search_replacements ))
238
346
return extra_select , order_by , group_by
239
347
240
348
def execute_sql (
@@ -573,10 +681,16 @@ def get_lookup_pipeline(self):
573
681
return result
574
682
575
683
def _get_aggregate_expressions(self, expr):
    """Yield every Aggregate node nested anywhere inside *expr*."""
    yield from self._get_all_expressions_of_type(expr, Aggregate)
685
+
686
def _get_search_expressions(self, expr):
    """Yield every SearchExpression node nested anywhere inside *expr*."""
    yield from self._get_all_expressions_of_type(expr, SearchExpression)
688
+
689
+ def _get_all_expressions_of_type (self , expr , target_type ):
576
690
stack = [expr ]
577
691
while stack :
578
692
expr = stack .pop ()
579
- if isinstance (expr , Aggregate ):
693
+ if isinstance (expr , target_type ):
580
694
yield expr
581
695
elif hasattr (expr , "get_source_expressions" ):
582
696
stack .extend (expr .get_source_expressions ())
@@ -645,6 +759,9 @@ def _get_ordering(self):
645
759
def get_where(self):
    """Return the effective WHERE node: an override set via set_where(), or the query's own."""
    try:
        return self.where
    except AttributeError:
        return self.query.where
return getattr (self , "where" , self .query .where )
647
761
762
def set_where(self, value):
    """Override the WHERE node returned by get_where()."""
    setattr(self, "where", value)
764
+
648
765
def explain_query (self ):
649
766
# Validate format (none supported) and options.
650
767
options = self .connection .ops .explain_query_prefix (
0 commit comments