Skip to content

Commit 0e1530b

Browse files
committed
Add SearchExpressions
1 parent f3185a8 commit 0e1530b

File tree

7 files changed

+734
-25
lines changed

7 files changed

+734
-25
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ repos:
8181
rev: "v2.2.6"
8282
hooks:
8383
- id: codespell
84-
args: ["-L", "nin"]
84+
args: ["-L", "nin", "-L", "searchin"]

django_mongodb_backend/compiler.py

Lines changed: 82 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20+
from .expressions.builtins import SearchExpression
2021
from .query import MongoQuery, wrap_database_errors
2122

2223

@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
3435
# A list of OrderBy objects for this query.
3536
self.order_by_objs = None
3637
self.subqueries = []
38+
# Atlas search calls
39+
self.search_pipeline = []
3740

3841
def _get_group_alias_column(self, expr, annotation_group_idx):
3942
"""Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
5760
column_target.set_attributes_from_name(alias)
5861
return Col(self.collection_name, column_target)
5962

63+
def _get_replace_expr(self, sub_expr, group, alias):
    """
    Register ``sub_expr`` in ``group`` under ``alias`` and return the
    expression that should replace ``sub_expr`` in the rest of the query.

    ``group`` is mutated in place: ``group[alias]`` receives the MQL
    generated for ``sub_expr``. The returned expression is a column that
    refers to ``alias``, possibly wrapped (see the inline comments).
    """
    # Build a column named `alias` whose field is a clone of sub_expr's
    # output field.
    target_field = sub_expr.output_field.clone()
    target_field.db_column = alias
    target_field.set_attributes_from_name(alias)
    alias_column = Col(self.collection_name, target_field)
    # getattr() with a default tolerates expressions that have no
    # `distinct` attribute.
    if getattr(sub_expr, "distinct", False):
        # If the expression should return distinct values, use
        # $addToSet to deduplicate.
        inner_mql = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        group[alias] = {"$addToSet": inner_mql}
        # Keep the original expression, but point it at the aliased column.
        replacement = sub_expr.copy()
        replacement.set_source_expressions([alias_column, None])
    else:
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacement = alias_column
    # Count must return 0 rather than null.
    if isinstance(sub_expr, Count):
        replacement = Coalesce(replacement, 0)
    # Variance = StdDev^2
    if isinstance(sub_expr, Variance):
        replacement = Power(replacement, 2)
    return replacement
85+
6086
def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
6187
"""
6288
Prepare expressions for the aggregation pipeline.
@@ -80,29 +106,41 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
80106
alias = (
81107
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
82108
)
83-
column_target = sub_expr.output_field.clone()
84-
column_target.db_column = alias
85-
column_target.set_attributes_from_name(alias)
86-
inner_column = Col(self.collection_name, column_target)
87-
if sub_expr.distinct:
88-
# If the expression should return distinct values, use
89-
# $addToSet to deduplicate.
90-
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
91-
group[alias] = {"$addToSet": rhs}
92-
replacing_expr = sub_expr.copy()
93-
replacing_expr.set_source_expressions([inner_column, None])
94-
else:
95-
group[alias] = sub_expr.as_mql(self, self.connection)
96-
replacing_expr = inner_column
97-
# Count must return 0 rather than null.
98-
if isinstance(sub_expr, Count):
99-
replacing_expr = Coalesce(replacing_expr, 0)
100-
# Variance = StdDev^2
101-
if isinstance(sub_expr, Variance):
102-
replacing_expr = Power(replacing_expr, 2)
103-
replacements[sub_expr] = replacing_expr
109+
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
104110
return replacements, group
105111

112+
def _prepare_search_expressions_for_pipeline(self, expression, target, search_idx, replacements):
    """
    Collect the search expressions found in ``expression``.

    For every search expression that is not already a key of
    ``replacements``, allocate the alias
    ``__search_expr.search<N>`` (``N`` taken from the ``search_idx``
    iterator) and store the replacing expression in ``replacements``
    (mutated in place).

    ``target`` is accepted for signature parity with the other
    ``_prepare_*`` helpers but is not used here.

    Return the list of MQL values generated for the newly registered
    search expressions, in registration order.
    """
    stages_by_alias = {}
    for search_expr in self._get_search_expressions(expression):
        if search_expr in replacements:
            # Already handled; don't consume another index.
            continue
        alias = f"__search_expr.search{next(search_idx)}"
        replacements[search_expr] = self._get_replace_expr(search_expr, stages_by_alias, alias)
    return list(stages_by_alias.values())
121+
122+
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
    """
    Gather the search expressions used by the query.

    The selected annotations, the ``order_by`` expressions and
    ``self.having`` are scanned, in that order, sharing one index
    counter (starting at 1) and one ``replacements`` mapping.

    Return a ``(searches, replacements)`` tuple: the accumulated list
    returned by ``_prepare_search_expressions_for_pipeline()`` for each
    scanned expression, and the mapping of search expressions to the
    expressions that replace them.
    """
    replacements = {}
    searches = []
    search_idx = itertools.count(start=1)
    for target, expr in self.query.annotation_select.items():
        searches.extend(
            self._prepare_search_expressions_for_pipeline(expr, target, search_idx, replacements)
        )
    for expr, _ in order_by:
        searches.extend(
            self._prepare_search_expressions_for_pipeline(expr, None, search_idx, replacements)
        )
    searches.extend(
        self._prepare_search_expressions_for_pipeline(
            self.having, None, search_idx, replacements
        )
    )
    return searches, replacements
143+
106144
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
107145
"""Prepare annotations for the aggregation pipeline."""
108146
replacements = {}
@@ -179,6 +217,9 @@ def _get_group_id_expressions(self, order_by):
179217
ids = self.get_project_fields(tuple(columns), force_expression=True)
180218
return ids, replacements
181219

220+
def _build_search_pipeline(self, search_queries):
    """
    Placeholder: currently does nothing with ``search_queries`` and
    returns None.
    """
    return None
222+
182223
def _build_aggregation_pipeline(self, ids, group):
183224
"""Build the aggregation pipeline for grouping."""
184225
pipeline = []
@@ -207,9 +248,21 @@ def _build_aggregation_pipeline(self, ids, group):
207248
pipeline.append({"$unset": "_id"})
208249
return pipeline
209250

251+
def _compound_searches_queries(self, searches):
    """
    Turn the collected search stages into pipeline stages.

    Return an empty list when ``searches`` is empty. For a single
    search, return that stage followed by an ``$addFields`` stage that
    stores ``{"$meta": "searchScore"}`` under ``__search_expr.search1``.

    Raise ValueError if ``searches`` holds more than one entry.
    """
    if not searches:
        return []
    if len(searches) > 1:
        raise ValueError("Cannot perform more than one search operation.")
    score_stage = {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}
    return [searches[0], score_stage]
257+
210258
def pre_sql_setup(self, with_col_aliases=False):
211259
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
212-
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
260+
searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
261+
order_by
262+
)
263+
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
264+
all_replacements = {**search_replacements, **group_replacements}
265+
self.search_pipeline = self._compound_searches_queries(searches)
213266
# query.group_by is either:
214267
# - None: no GROUP BY
215268
# - True: group by select fields
@@ -557,10 +610,16 @@ def get_lookup_pipeline(self):
557610
return result
558611

559612
def _get_aggregate_expressions(self, expr):
    """Return an iterator over the Aggregate expressions found in ``expr``."""
    return self._get_all_expressions_of_type(expr, Aggregate)
614+
615+
def _get_search_expressions(self, expr):
    """Return an iterator over the SearchExpression nodes found in ``expr``."""
    return self._get_all_expressions_of_type(expr, SearchExpression)
617+
618+
def _get_all_expressions_of_type(self, expr, target_type):
    """
    Yield every node of the expression tree rooted at ``expr`` that is an
    instance of ``target_type``.

    The tree is walked iteratively with an explicit stack. A matching
    node is yielded without descending into its own source expressions;
    a non-matching node is expanded only if it provides
    ``get_source_expressions()``.
    """
    pending = [expr]
    while pending:
        node = pending.pop()
        if isinstance(node, target_type):
            yield node
        elif hasattr(node, "get_source_expressions"):
            pending.extend(node.get_source_expressions())

0 commit comments

Comments
 (0)