@@ -16,6 +16,7 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

+from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors

@@ -33,6 +34,8 @@ def __init__(self, *args, **kwargs):
        # A list of OrderBy objects for this query.
        self.order_by_objs = None
        self.subqueries = []
+        # Atlas search calls
+        self.search_pipeline = []

    def _get_group_alias_column(self, expr, annotation_group_idx):
        """Generate a dummy field for use in the ids fields in $group."""
@@ -56,6 +59,29 @@ def _get_column_from_expression(self, expr, alias):
        column_target.set_attributes_from_name(alias)
        return Col(self.collection_name, column_target)

+    def _get_replace_expr(self, sub_expr, group, alias):
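+        """
+        Build the column that replaces sub_expr in the outer query and record
+        the expression's MQL in group under alias. Distinct aggregates are
+        collected with $addToSet; Count is coalesced to 0 and Variance is
+        computed as StdDev squared.
+        """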
+        column_target = sub_expr.output_field.clone()
+        column_target.db_column = alias
+        column_target.set_attributes_from_name(alias)
+        inner_column = Col(self.collection_name, column_target)
+        if getattr(sub_expr, "distinct", False):
+            # If the expression should return distinct values, use
+            # $addToSet to deduplicate.
+            rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
+            group[alias] = {"$addToSet": rhs}
+            replacing_expr = sub_expr.copy()
+            replacing_expr.set_source_expressions([inner_column, None])
+        else:
+            group[alias] = sub_expr.as_mql(self, self.connection)
+            replacing_expr = inner_column
+        # Count must return 0 rather than null.
+        if isinstance(sub_expr, Count):
+            replacing_expr = Coalesce(replacing_expr, 0)
+        # Variance = StdDev^2
+        if isinstance(sub_expr, Variance):
+            replacing_expr = Power(replacing_expr, 2)
+        return replacing_expr
+
    def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
        """
        Prepare expressions for the aggregation pipeline.
@@ -79,29 +105,33 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
                alias = (
                    f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
                )
-                column_target = sub_expr.output_field.clone()
-                column_target.db_column = alias
-                column_target.set_attributes_from_name(alias)
-                inner_column = Col(self.collection_name, column_target)
-                if sub_expr.distinct:
-                    # If the expression should return distinct values, use
-                    # $addToSet to deduplicate.
-                    rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
-                    group[alias] = {"$addToSet": rhs}
-                    replacing_expr = sub_expr.copy()
-                    replacing_expr.set_source_expressions([inner_column, None])
-                else:
-                    group[alias] = sub_expr.as_mql(self, self.connection)
-                    replacing_expr = inner_column
-                # Count must return 0 rather than null.
-                if isinstance(sub_expr, Count):
-                    replacing_expr = Coalesce(replacing_expr, 0)
-                # Variance = StdDev^2
-                if isinstance(sub_expr, Variance):
-                    replacing_expr = Power(replacing_expr, 2)
-                replacements[sub_expr] = replacing_expr
+                replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
        return replacements, group

+    def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
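+        """
+        Collect the Atlas search expressions used by expression and register a
+        replacement column (aliased __search_expr.search<N>) for each one not
+        already replaced.
+        """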
+        searches = {}
+        for sub_expr in self._get_search_expressions(expression):
+            if sub_expr not in replacements:
+                alias = f"__search_expr.search{next(search_idx)}"
+                replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
+
+    def _prepare_search_query_for_aggregation_pipeline(self, order_by):
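+        """
+        Scan the annotations, ORDER BY, HAVING, and WHERE clauses for Atlas
+        search expressions and return a mapping from each search expression
+        to the column that will hold its result.
+        """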
+        replacements = {}
+        annotation_group_idx = itertools.count(start=1)
+        for expr in self.query.annotation_select.values():
+            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
+
+        for expr, _ in order_by:
+            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
+
+        self._prepare_search_expressions_for_pipeline(
+            self.having, annotation_group_idx, replacements
+        )
+        self._prepare_search_expressions_for_pipeline(
+            self.get_where(), annotation_group_idx, replacements
+        )
+        return replacements
+
    def _prepare_annotations_for_aggregation_pipeline(self, order_by):
        """Prepare annotations for the aggregation pipeline."""
        replacements = {}
@@ -206,9 +236,57 @@ def _build_aggregation_pipeline(self, ids, group):
        pipeline.append({"$unset": "_id"})
        return pipeline

+    def _compound_searches_queries(self, search_replacements):
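+        """
+        Return the pipeline stages for the query's single Atlas search
+        expression: the $search/$vectorSearch stage followed by an $addFields
+        that stores the relevance score ($meta: searchScore or
+        vectorSearchScore) in the replacement column. Multiple search
+        expressions in one query raise ValueError.
+        """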
+        if not search_replacements:
+            return []
+        if len(search_replacements) > 1:
+            has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
+            has_vector_search = any(
+                isinstance(search, SearchVector) for search in search_replacements
+            )
+            if has_search and has_vector_search:
+                raise ValueError(
+                    "Cannot combine a `$vectorSearch` with a `$search` operator. "
+                    "If you need to combine them, consider restructuring your query logic or "
+                    "running them as separate queries."
+                )
+            if not has_search:
+                raise ValueError(
+                    "Cannot combine two `$vectorSearch` operators. "
+                    "If you need to combine them, consider restructuring your query logic or "
+                    "running them as separate queries."
+                )
+            raise ValueError(
+                "Only one $search operation is allowed per query. "
+                f"Received {len(search_replacements)} search expressions. "
+                "To combine multiple search expressions, use either a CompoundExpression for "
+                "fine-grained control or CombinedSearchExpression for simple logical combinations."
+            )
+        pipeline = []
+        for search, result_col in search_replacements.items():
+            score_function = (
+                "vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
+            )
+            pipeline.extend(
+                [
+                    search.as_mql(self, self.connection),
+                    {
+                        "$addFields": {
+                            result_col.as_mql(self, self.connection, as_path=True): {
+                                "$meta": score_function
+                            }
+                        }
+                    },
+                ]
+            )
+        return pipeline
+
    def pre_sql_setup(self, with_col_aliases=False):
        extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
-        group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
+        group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        all_replacements = {**search_replacements, **group_replacements}
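+        # Stages for the query's Atlas search expression: the
+        # $search/$vectorSearch stage and the score-projecting $addFields.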
+        self.search_pipeline = self._compound_searches_queries(search_replacements)
        # query.group_by is either:
        # - None: no GROUP BY
        # - True: group by select fields
@@ -233,6 +311,9 @@ def pre_sql_setup(self, with_col_aliases=False):
            for target, expr in self.query.annotation_select.items()
        }
        self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
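+        # Swap any search expressions in the WHERE clause for the columns
+        # that hold their results.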
+        if (where := self.get_where()) and search_replacements:
+            where = where.replace_expressions(search_replacements)
+            self.set_where(where)
        return extra_select, order_by, group_by

    def execute_sql(
@@ -556,10 +637,16 @@ def get_lookup_pipeline(self):
        return result

    def _get_aggregate_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, Aggregate)
+
+    def _get_search_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, SearchExpression)
+
+    def _get_all_expressions_of_type(self, expr, target_type):
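+        """Yield the expressions of target_type found in expr's expression tree."""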
        stack = [expr]
        while stack:
            expr = stack.pop()
-            if isinstance(expr, Aggregate):
+            if isinstance(expr, target_type):
                yield expr
            elif hasattr(expr, "get_source_expressions"):
                stack.extend(expr.get_source_expressions())
@@ -628,6 +715,9 @@ def _get_ordering(self):
    def get_where(self):
        return getattr(self, "where", self.query.where)

+    def set_where(self, value):
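+        """Set the where clause that get_where() returns."""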
+        self.where = value
+
    def explain_query(self):
        # Validate format (none supported) and options.
        options = self.connection.ops.explain_query_prefix(