from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

+from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors

@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
        # A list of OrderBy objects for this query.
        self.order_by_objs = None
        self.subqueries = []
+        # Atlas Search ($search / $vectorSearch) stages for this query.
+        self.search_pipeline = []

    def _get_group_alias_column(self, expr, annotation_group_idx):
        """Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
        column_target.set_attributes_from_name(alias)
        return Col(self.collection_name, column_target)

+    def _get_replace_expr(self, sub_expr, group, alias):
+        column_target = sub_expr.output_field.clone()
+        column_target.db_column = alias
+        column_target.set_attributes_from_name(alias)
+        inner_column = Col(self.collection_name, column_target)
+        if getattr(sub_expr, "distinct", False):
+            # If the expression should return distinct values, use
+            # $addToSet to deduplicate.
+            rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
+            group[alias] = {"$addToSet": rhs}
+            replacing_expr = sub_expr.copy()
+            replacing_expr.set_source_expressions([inner_column, None])
+        else:
+            group[alias] = sub_expr.as_mql(self, self.connection)
+            replacing_expr = inner_column
+        # Count must return 0 rather than null.
+        if isinstance(sub_expr, Count):
+            replacing_expr = Coalesce(replacing_expr, 0)
+        # Variance = StdDev^2
+        if isinstance(sub_expr, Variance):
+            replacing_expr = Power(replacing_expr, 2)
+        return replacing_expr
+
    def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
        """
        Prepare expressions for the aggregation pipeline.
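As a rough illustration of what the extracted helper contributes (the alias and field names below are hypothetical, not taken from this diff), a distinct `Count` ends up adding an `$addToSet` accumulator to the `$group` document, and the returned replacement refers to the materialized column:

```python
# Sketch only: the $group entry _get_replace_expr would write for a
# hypothetical Count("rating", distinct=True) annotation.
group = {"__aggregation1": {"$addToSet": "$rating"}}

# The replacement expression then counts the deduplicated set and is wrapped in
# Coalesce(..., 0) so an empty group reports 0 rather than null.
```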
@@ -80,29 +106,33 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
            alias = (
                f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
            )
-            column_target = sub_expr.output_field.clone()
-            column_target.db_column = alias
-            column_target.set_attributes_from_name(alias)
-            inner_column = Col(self.collection_name, column_target)
-            if sub_expr.distinct:
-                # If the expression should return distinct values, use
-                # $addToSet to deduplicate.
-                rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
-                group[alias] = {"$addToSet": rhs}
-                replacing_expr = sub_expr.copy()
-                replacing_expr.set_source_expressions([inner_column, None])
-            else:
-                group[alias] = sub_expr.as_mql(self, self.connection)
-                replacing_expr = inner_column
-            # Count must return 0 rather than null.
-            if isinstance(sub_expr, Count):
-                replacing_expr = Coalesce(replacing_expr, 0)
-            # Variance = StdDev^2
-            if isinstance(sub_expr, Variance):
-                replacing_expr = Power(replacing_expr, 2)
-            replacements[sub_expr] = replacing_expr
+            replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
        return replacements, group

+    def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
+        searches = {}
+        for sub_expr in self._get_search_expressions(expression):
+            if sub_expr not in replacements:
+                alias = f"__search_expr.search{next(search_idx)}"
+                replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
+
+    def _prepare_search_query_for_aggregation_pipeline(self, order_by):
+        replacements = {}
+        annotation_group_idx = itertools.count(start=1)
+        for expr in self.query.annotation_select.values():
+            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
+
+        for expr, _ in order_by:
+            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
+
+        self._prepare_search_expressions_for_pipeline(
+            self.having, annotation_group_idx, replacements
+        )
+        self._prepare_search_expressions_for_pipeline(
+            self.get_where(), annotation_group_idx, replacements
+        )
+        return replacements
+
    def _prepare_annotations_for_aggregation_pipeline(self, order_by):
        """Prepare annotations for the aggregation pipeline."""
        replacements = {}
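The search preparation above assigns one stable alias per distinct search expression found while walking the annotations, ORDER BY, HAVING, and WHERE trees; duplicates reuse the first alias. A standalone sketch of that aliasing scheme (the strings below are stand-ins for expression objects, not real code from this diff):

```python
import itertools

# Stand-ins for search expressions encountered while walking the query.
found = ["text_search", "text_search", "vector_search"]

search_idx = itertools.count(start=1)
replacements = {}
for expr in found:
    if expr not in replacements:
        replacements[expr] = f"__search_expr.search{next(search_idx)}"

print(replacements)
# {'text_search': '__search_expr.search1', 'vector_search': '__search_expr.search2'}
```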
@@ -207,9 +237,57 @@ def _build_aggregation_pipeline(self, ids, group):
            pipeline.append({"$unset": "_id"})
        return pipeline

+    def _compound_searches_queries(self, search_replacements):
+        if not search_replacements:
+            return []
+        if len(search_replacements) > 1:
+            has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
+            has_vector_search = any(
+                isinstance(search, SearchVector) for search in search_replacements
+            )
+            if has_search and has_vector_search:
+                raise ValueError(
+                    "Cannot combine a `$vectorSearch` with a `$search` operator. "
+                    "If you need to combine them, consider restructuring your query logic or "
+                    "running them as separate queries."
+                )
+            if not has_search:
+                raise ValueError(
+                    "Cannot combine two `$vectorSearch` operators. "
+                    "If you need to combine them, consider restructuring your query logic or "
+                    "running them as separate queries."
+                )
+            raise ValueError(
+                "Only one $search operation is allowed per query. "
+                f"Received {len(search_replacements)} search expressions. "
+                "To combine multiple search expressions, use either a CompoundExpression for "
+                "fine-grained control or CombinedSearchExpression for simple logical combinations."
+            )
+        pipeline = []
+        for search, result_col in search_replacements.items():
+            score_function = (
+                "vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
+            )
+            pipeline.extend(
+                [
+                    search.as_mql(self, self.connection),
+                    {
+                        "$addFields": {
+                            result_col.as_mql(self, self.connection, as_path=True): {
+                                "$meta": score_function
+                            }
+                        }
+                    },
+                ]
+            )
+        return pipeline
+
    def pre_sql_setup(self, with_col_aliases=False):
        extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
-        group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
+        group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        all_replacements = {**search_replacements, **group_replacements}
+        self.search_pipeline = self._compound_searches_queries(search_replacements)
        # query.group_by is either:
        # - None: no GROUP BY
        # - True: group by select fields
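For a single text-search annotation, the stages stored on `self.search_pipeline` by `_compound_searches_queries` look roughly like the following; the index name, query, and path are illustrative, and a `SearchVector` expression would instead produce a `$vectorSearch` stage scored with `vectorSearchScore`:

```python
# Illustrative shape of the prepended stages (not taken verbatim from this diff):
search_pipeline = [
    {"$search": {"index": "default", "text": {"query": "coffee", "path": "description"}}},
    # The score is materialized under the expression's alias so later stages
    # ($match, $sort, projections) can reference it like any other field.
    {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}},
]
```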
@@ -234,6 +312,9 @@ def pre_sql_setup(self, with_col_aliases=False):
            for target, expr in self.query.annotation_select.items()
        }
        self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
+        if (where := self.get_where()) and search_replacements:
+            where = where.replace_expressions(search_replacements)
+            self.set_where(where)
        return extra_select, order_by, group_by

    def execute_sql(
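Because the WHERE tree is rewritten with `search_replacements`, a filter on a search annotation is expected to compile against the materialized score field rather than re-evaluating the search operator inside `$match`. A heavily simplified sketch (alias and threshold are hypothetical):

```python
# Sketch only: a queryset filter such as .filter(score__gte=0.5) on a search
# annotation should end up as an ordinary $match over the $addFields alias.
match_stage = {"$match": {"__search_expr.search1": {"$gte": 0.5}}}
```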
@@ -557,10 +638,16 @@ def get_lookup_pipeline(self):
        return result

    def _get_aggregate_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, Aggregate)
+
+    def _get_search_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, SearchExpression)
+
+    def _get_all_expressions_of_type(self, expr, target_type):
        stack = [expr]
        while stack:
            expr = stack.pop()
-            if isinstance(expr, Aggregate):
+            if isinstance(expr, target_type):
                yield expr
            elif hasattr(expr, "get_source_expressions"):
                stack.extend(expr.get_source_expressions())
@@ -629,6 +716,9 @@ def _get_ordering(self):
    def get_where(self):
        return getattr(self, "where", self.query.where)

+    def set_where(self, value):
+        self.where = value
+
    def explain_query(self):
        # Validate format (none supported) and options.
        options = self.connection.ops.explain_query_prefix(