Skip to content

Commit 1874592

Browse files
committed
Refactor
1 parent 7d05c2c commit 1874592

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

django_mongodb_backend/compiler.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20-
from .expressions.builtins import SearchExpression
20+
from .expressions.builtins import SearchExpression, SearchVector
2121
from .query import MongoQuery, wrap_database_errors
2222

2323

@@ -117,29 +117,24 @@ def _prepare_search_expressions_for_pipeline(
117117
if sub_expr not in replacements:
118118
alias = f"__search_expr.search{next(search_idx)}"
119119
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
120-
return list(searches.values())
121120

122121
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
123122
replacements = {}
124-
searches = []
125123
annotation_group_idx = itertools.count(start=1)
126124
for target, expr in self.query.annotation_select.items():
127-
expr_searches = self._prepare_search_expressions_for_pipeline(
125+
self._prepare_search_expressions_for_pipeline(
128126
expr, target, annotation_group_idx, replacements
129127
)
130-
searches += expr_searches
131128

132129
for expr, _ in order_by:
133-
expr_searches = self._prepare_search_expressions_for_pipeline(
130+
self._prepare_search_expressions_for_pipeline(
134131
expr, None, annotation_group_idx, replacements
135132
)
136-
searches += expr_searches
137133

138-
having_group = self._prepare_search_expressions_for_pipeline(
134+
self._prepare_search_expressions_for_pipeline(
139135
self.having, None, annotation_group_idx, replacements
140136
)
141-
searches += having_group
142-
return searches, replacements
137+
return replacements
143138

144139
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
145140
"""Prepare annotations for the aggregation pipeline."""
@@ -248,22 +243,36 @@ def _build_aggregation_pipeline(self, ids, group):
248243
pipeline.append({"$unset": "_id"})
249244
return pipeline
250245

251-
def _compound_searches_queries(self, searches, search_replacements):
252-
if not searches:
246+
def _compound_searches_queries(self, search_replacements):
247+
if not search_replacements:
253248
return []
254-
if len(searches) > 1:
249+
if len(search_replacements) > 1:
255250
raise ValueError("Cannot perform more than one search operation.")
256-
score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore"
257-
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}]
251+
pipeline = []
252+
for search, result_col in search_replacements.items():
253+
score_function = (
254+
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
255+
)
256+
pipeline.extend(
257+
[
258+
search.as_mql(self, self.connection),
259+
{
260+
"$addFields": {
261+
result_col.as_mql(self, self.connection).removeprefix("$"): {
262+
"$meta": score_function
263+
}
264+
}
265+
},
266+
]
267+
)
268+
return pipeline
258269

259270
def pre_sql_setup(self, with_col_aliases=False):
260271
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
261-
searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
262-
order_by
263-
)
272+
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
264273
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
265274
all_replacements = {**search_replacements, **group_replacements}
266-
self.search_pipeline = self._compound_searches_queries(searches, search_replacements)
275+
self.search_pipeline = self._compound_searches_queries(search_replacements)
267276
# query.group_by is either:
268277
# - None: no GROUP BY
269278
# - True: group by select fields

0 commit comments

Comments
 (0)