|
17 | 17 | from django.utils.functional import cached_property
|
18 | 18 | from pymongo import ASCENDING, DESCENDING
|
19 | 19 |
|
20 |
| -from .expressions.builtins import SearchExpression |
| 20 | +from .expressions.builtins import SearchExpression, SearchVector |
21 | 21 | from .query import MongoQuery, wrap_database_errors
|
22 | 22 |
|
23 | 23 |
|
@@ -109,37 +109,26 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
|
109 | 109 | replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
|
110 | 110 | return replacements, group
|
111 | 111 |
|
112 |
| - def _prepare_search_expressions_for_pipeline( |
113 |
| - self, expression, target, search_idx, replacements |
114 |
| - ): |
| 112 | + def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements): |
115 | 113 | searches = {}
|
116 | 114 | for sub_expr in self._get_search_expressions(expression):
|
117 | 115 | if sub_expr not in replacements:
|
118 | 116 | alias = f"__search_expr.search{next(search_idx)}"
|
119 | 117 | replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
|
120 |
| - return list(searches.values()) |
121 | 118 |
|
122 | 119 | def _prepare_search_query_for_aggregation_pipeline(self, order_by):
|
123 | 120 | replacements = {}
|
124 |
| - searches = [] |
125 | 121 | annotation_group_idx = itertools.count(start=1)
|
126 |
| - for target, expr in self.query.annotation_select.items(): |
127 |
| - expr_searches = self._prepare_search_expressions_for_pipeline( |
128 |
| - expr, target, annotation_group_idx, replacements |
129 |
| - ) |
130 |
| - searches += expr_searches |
| 122 | + for expr in self.query.annotation_select.values(): |
| 123 | + self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements) |
131 | 124 |
|
132 | 125 | for expr, _ in order_by:
|
133 |
| - expr_searches = self._prepare_search_expressions_for_pipeline( |
134 |
| - expr, None, annotation_group_idx, replacements |
135 |
| - ) |
136 |
| - searches += expr_searches |
| 126 | + self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements) |
137 | 127 |
|
138 |
| - having_group = self._prepare_search_expressions_for_pipeline( |
139 |
| - self.having, None, annotation_group_idx, replacements |
| 128 | + self._prepare_search_expressions_for_pipeline( |
| 129 | + self.having, annotation_group_idx, replacements |
140 | 130 | )
|
141 |
| - searches += having_group |
142 |
| - return searches, replacements |
| 131 | + return replacements |
143 | 132 |
|
144 | 133 | def _prepare_annotations_for_aggregation_pipeline(self, order_by):
|
145 | 134 | """Prepare annotations for the aggregation pipeline."""
|
@@ -248,22 +237,36 @@ def _build_aggregation_pipeline(self, ids, group):
|
248 | 237 | pipeline.append({"$unset": "_id"})
|
249 | 238 | return pipeline
|
250 | 239 |
|
251 |
| - def _compound_searches_queries(self, searches, search_replacements): |
252 |
| - if not searches: |
| 240 | + def _compound_searches_queries(self, search_replacements): |
| 241 | + if not search_replacements: |
253 | 242 | return []
|
254 |
| - if len(searches) > 1: |
| 243 | + if len(search_replacements) > 1: |
255 | 244 | raise ValueError("Cannot perform more than one search operation.")
|
256 |
| - score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore" |
257 |
| - return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}] |
| 245 | + pipeline = [] |
| 246 | + for search, result_col in search_replacements.items(): |
| 247 | + score_function = ( |
| 248 | + "vectorSearchScore" if isinstance(search, SearchVector) else "searchScore" |
| 249 | + ) |
| 250 | + pipeline.extend( |
| 251 | + [ |
| 252 | + search.as_mql(self, self.connection), |
| 253 | + { |
| 254 | + "$addFields": { |
| 255 | + result_col.as_mql(self, self.connection).removeprefix("$"): { |
| 256 | + "$meta": score_function |
| 257 | + } |
| 258 | + } |
| 259 | + }, |
| 260 | + ] |
| 261 | + ) |
| 262 | + return pipeline |
258 | 263 |
|
259 | 264 | def pre_sql_setup(self, with_col_aliases=False):
|
260 | 265 | extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
|
261 |
| - searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline( |
262 |
| - order_by |
263 |
| - ) |
| 266 | + search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by) |
264 | 267 | group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
|
265 | 268 | all_replacements = {**search_replacements, **group_replacements}
|
266 |
| - self.search_pipeline = self._compound_searches_queries(searches, search_replacements) |
| 269 | + self.search_pipeline = self._compound_searches_queries(search_replacements) |
267 | 270 | # query.group_by is either:
|
268 | 271 | # - None: no GROUP BY
|
269 | 272 | # - True: group by select fields
|
|
0 commit comments