17
17
from django .utils .functional import cached_property
18
18
from pymongo import ASCENDING , DESCENDING
19
19
20
+ from .expressions .builtins import SearchExpression
20
21
from .query import MongoQuery , wrap_database_errors
21
22
22
23
@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
34
35
# A list of OrderBy objects for this query.
35
36
self .order_by_objs = None
36
37
self .subqueries = []
38
+ # Atlas search calls
39
+ self .search_pipeline = []
37
40
38
41
def _get_group_alias_column (self , expr , annotation_group_idx ):
39
42
"""Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
57
60
column_target .set_attributes_from_name (alias )
58
61
return Col (self .collection_name , column_target )
59
62
63
+ def _get_replace_expr (self , sub_expr , group , alias ):
64
+ column_target = sub_expr .output_field .clone ()
65
+ column_target .db_column = alias
66
+ column_target .set_attributes_from_name (alias )
67
+ inner_column = Col (self .collection_name , column_target )
68
+ if getattr (sub_expr , "distinct" , False ):
69
+ # If the expression should return distinct values, use
70
+ # $addToSet to deduplicate.
71
+ rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
72
+ group [alias ] = {"$addToSet" : rhs }
73
+ replacing_expr = sub_expr .copy ()
74
+ replacing_expr .set_source_expressions ([inner_column , None ])
75
+ else :
76
+ group [alias ] = sub_expr .as_mql (self , self .connection )
77
+ replacing_expr = inner_column
78
+ # Count must return 0 rather than null.
79
+ if isinstance (sub_expr , Count ):
80
+ replacing_expr = Coalesce (replacing_expr , 0 )
81
+ # Variance = StdDev^2
82
+ if isinstance (sub_expr , Variance ):
83
+ replacing_expr = Power (replacing_expr , 2 )
84
+ return replacing_expr
85
+
60
86
def _prepare_expressions_for_pipeline (self , expression , target , annotation_group_idx ):
61
87
"""
62
88
Prepare expressions for the aggregation pipeline.
@@ -80,29 +106,41 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
80
106
alias = (
81
107
f"__aggregation{ next (annotation_group_idx )} " if sub_expr != expression else target
82
108
)
83
- column_target = sub_expr .output_field .clone ()
84
- column_target .db_column = alias
85
- column_target .set_attributes_from_name (alias )
86
- inner_column = Col (self .collection_name , column_target )
87
- if sub_expr .distinct :
88
- # If the expression should return distinct values, use
89
- # $addToSet to deduplicate.
90
- rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
91
- group [alias ] = {"$addToSet" : rhs }
92
- replacing_expr = sub_expr .copy ()
93
- replacing_expr .set_source_expressions ([inner_column , None ])
94
- else :
95
- group [alias ] = sub_expr .as_mql (self , self .connection )
96
- replacing_expr = inner_column
97
- # Count must return 0 rather than null.
98
- if isinstance (sub_expr , Count ):
99
- replacing_expr = Coalesce (replacing_expr , 0 )
100
- # Variance = StdDev^2
101
- if isinstance (sub_expr , Variance ):
102
- replacing_expr = Power (replacing_expr , 2 )
103
- replacements [sub_expr ] = replacing_expr
109
+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , group , alias )
104
110
return replacements , group
105
111
112
+ def _prepare_search_expressions_for_pipeline (
113
+ self , expression , target , search_idx , replacements
114
+ ):
115
+ searches = {}
116
+ for sub_expr in self ._get_search_expressions (expression ):
117
+ if sub_expr not in replacements :
118
+ alias = f"__search_expr.search{ next (search_idx )} "
119
+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , searches , alias )
120
+ return list (searches .values ())
121
+
122
+ def _prepare_search_query_for_aggregation_pipeline (self , order_by ):
123
+ replacements = {}
124
+ searches = []
125
+ annotation_group_idx = itertools .count (start = 1 )
126
+ for target , expr in self .query .annotation_select .items ():
127
+ expr_searches = self ._prepare_search_expressions_for_pipeline (
128
+ expr , target , annotation_group_idx , replacements
129
+ )
130
+ searches += expr_searches
131
+
132
+ for expr , _ in order_by :
133
+ expr_searches = self ._prepare_search_expressions_for_pipeline (
134
+ expr , None , annotation_group_idx , replacements
135
+ )
136
+ searches += expr_searches
137
+
138
+ having_group = self ._prepare_search_expressions_for_pipeline (
139
+ self .having , None , annotation_group_idx , replacements
140
+ )
141
+ searches += having_group
142
+ return searches , replacements
143
+
106
144
def _prepare_annotations_for_aggregation_pipeline (self , order_by ):
107
145
"""Prepare annotations for the aggregation pipeline."""
108
146
replacements = {}
@@ -179,6 +217,9 @@ def _get_group_id_expressions(self, order_by):
179
217
ids = self .get_project_fields (tuple (columns ), force_expression = True )
180
218
return ids , replacements
181
219
220
+ def _build_search_pipeline (self , search_queries ):
221
+ pass
222
+
182
223
def _build_aggregation_pipeline (self , ids , group ):
183
224
"""Build the aggregation pipeline for grouping."""
184
225
pipeline = []
@@ -207,9 +248,21 @@ def _build_aggregation_pipeline(self, ids, group):
207
248
pipeline .append ({"$unset" : "_id" })
208
249
return pipeline
209
250
251
+ def _compound_searches_queries (self , searches ):
252
+ if not searches :
253
+ return []
254
+ if len (searches ) > 1 :
255
+ raise ValueError ("Cannot perform more than one search operation." )
256
+ return [searches [0 ], {"$addFields" : {"__search_expr.search1" : {"$meta" : "searchScore" }}}]
257
+
210
258
def pre_sql_setup (self , with_col_aliases = False ):
211
259
extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
212
- group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
260
+ searches , search_replacements = self ._prepare_search_query_for_aggregation_pipeline (
261
+ order_by
262
+ )
263
+ group , group_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
264
+ all_replacements = {** search_replacements , ** group_replacements }
265
+ self .search_pipeline = self ._compound_searches_queries (searches )
213
266
# query.group_by is either:
214
267
# - None: no GROUP BY
215
268
# - True: group by select fields
@@ -557,10 +610,16 @@ def get_lookup_pipeline(self):
557
610
return result
558
611
559
612
def _get_aggregate_expressions (self , expr ):
613
+ return self ._get_all_expressions_of_type (expr , Aggregate )
614
+
615
+ def _get_search_expressions (self , expr ):
616
+ return self ._get_all_expressions_of_type (expr , SearchExpression )
617
+
618
+ def _get_all_expressions_of_type (self , expr , target_type ):
560
619
stack = [expr ]
561
620
while stack :
562
621
expr = stack .pop ()
563
- if isinstance (expr , Aggregate ):
622
+ if isinstance (expr , target_type ):
564
623
yield expr
565
624
elif hasattr (expr , "get_source_expressions" ):
566
625
stack .extend (expr .get_source_expressions ())
0 commit comments