Commit 87c2ecd

WaVEV authored and timgraham committed
INTPYTHON-522, INTPYTHON-524 Add support for Atlas and vector search queries
1 parent 395dd7e commit 87c2ecd

14 files changed: +2768, -27 lines
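At the ORM level, the new expressions are meant to be used as query annotations. A minimal usage sketch (the Article model, its fields, the import path, and the constructor arguments are assumptions for illustration; the exact signatures live in the new expressions/search module, which is not part of this excerpt):

    # Hypothetical usage sketch; the model, field names, import path, and
    # keyword arguments are assumptions, not taken from this commit.
    from django_mongodb_backend.expressions import SearchText, SearchVector

    # Full-text search: annotate each Article with an Atlas Search relevance
    # score and order by it.
    articles = Article.objects.annotate(
        score=SearchText(path="headline", query="django")
    ).order_by("-score")

    # Vector search: annotate with a $vectorSearch similarity score for a
    # precomputed query embedding.
    similar = Article.objects.annotate(
        score=SearchVector(path="embedding", query_vector=query_embedding, limit=10)
    ).order_by("-score")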

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -81,4 +81,4 @@ repos:
     rev: "v2.2.6"
     hooks:
       - id: codespell
-        args: ["-L", "nin"]
+        args: ["-L", "nin", "-L", "searchin"]

django_mongodb_backend/compiler.py

Lines changed: 140 additions & 23 deletions
@@ -17,6 +17,7 @@
 from django.utils.functional import cached_property
 from pymongo import ASCENDING, DESCENDING
 
+from .expressions.search import SearchExpression, SearchVector
 from .query import MongoQuery, wrap_database_errors
 from .query_utils import is_direct_value
 
@@ -35,6 +36,8 @@ def __init__(self, *args, **kwargs):
         # A list of OrderBy objects for this query.
         self.order_by_objs = None
         self.subqueries = []
+        # Atlas search stage.
+        self.search_pipeline = []
 
     def _get_group_alias_column(self, expr, annotation_group_idx):
         """Generate a dummy field for use in the ids fields in $group."""
@@ -58,6 +61,29 @@ def _get_column_from_expression(self, expr, alias):
         column_target.set_attributes_from_name(alias)
         return Col(self.collection_name, column_target)
 
+    def _get_replace_expr(self, sub_expr, group, alias):
+        column_target = sub_expr.output_field.clone()
+        column_target.db_column = alias
+        column_target.set_attributes_from_name(alias)
+        inner_column = Col(self.collection_name, column_target)
+        if getattr(sub_expr, "distinct", False):
+            # If the expression should return distinct values, use $addToSet to
+            # deduplicate.
+            rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
+            group[alias] = {"$addToSet": rhs}
+            replacing_expr = sub_expr.copy()
+            replacing_expr.set_source_expressions([inner_column, None])
+        else:
+            group[alias] = sub_expr.as_mql(self, self.connection)
+            replacing_expr = inner_column
+        # Count must return 0 rather than null.
+        if isinstance(sub_expr, Count):
+            replacing_expr = Coalesce(replacing_expr, 0)
+        # Variance = StdDev^2
+        if isinstance(sub_expr, Variance):
+            replacing_expr = Power(replacing_expr, 2)
+        return replacing_expr
+
     def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
         """
         Prepare expressions for the aggregation pipeline.
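To picture what _get_replace_expr sets up, here is an MQL-level sketch as plain Python dicts (the field names, alias, and surrounding stages are assumptions, not produced by this code): the aggregate's MQL is accumulated under a generated alias in $group, and the rest of the query is rewritten to reference that alias as an ordinary column.

    # Illustrative pipeline only; field names and stage contents are invented.
    pipeline = [
        # The aggregate is computed once, under the synthetic alias...
        {"$group": {"_id": None, "__aggregation1": {"$sum": 1}}},
        # ...and later stages refer to the alias instead of the aggregate.
        # Because Count must return 0 rather than null, the replacement is
        # wrapped in Coalesce, which corresponds to $ifNull here.
        {"$project": {"num_books": {"$ifNull": ["$__aggregation1", 0]}, "_id": 0}},
    ]
    print(pipeline)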
@@ -81,29 +107,51 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
             alias = (
                 f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
             )
-            column_target = sub_expr.output_field.clone()
-            column_target.db_column = alias
-            column_target.set_attributes_from_name(alias)
-            inner_column = Col(self.collection_name, column_target)
-            if sub_expr.distinct:
-                # If the expression should return distinct values, use
-                # $addToSet to deduplicate.
-                rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
-                group[alias] = {"$addToSet": rhs}
-                replacing_expr = sub_expr.copy()
-                replacing_expr.set_source_expressions([inner_column, None])
-            else:
-                group[alias] = sub_expr.as_mql(self, self.connection)
-                replacing_expr = inner_column
-            # Count must return 0 rather than null.
-            if isinstance(sub_expr, Count):
-                replacing_expr = Coalesce(replacing_expr, 0)
-            # Variance = StdDev^2
-            if isinstance(sub_expr, Variance):
-                replacing_expr = Power(replacing_expr, 2)
-            replacements[sub_expr] = replacing_expr
+            replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
         return replacements, group
 
+    def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
+        """
+        Collect and prepare unique search expressions for inclusion in an
+        aggregation pipeline.
+
+        Iterate over all search sub-expressions of the given expression,
+        assigning a unique alias to each and mapping them to their replacement
+        expressions.
+        """
+        searches = {}
+        for sub_expr in self._get_search_expressions(expression):
+            if sub_expr not in replacements:
+                alias = f"__search_expr.search{next(search_idx)}"
+                replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
+
+    def _prepare_search_query_for_aggregation_pipeline(self, order_by):
+        """
+        Prepare expressions for the search pipeline.
+
+        Handle the computation of search functions used by various expressions.
+        Separate and create intermediate columns, and replace nodes to simulate
+        a search operation.
+
+        To apply operations over the $search or $vectorSearch stages, compute
+        the $search or $vectorSearch first, then apply additional operations in
+        a subsequent stage by replacing the aggregate expressions with a new
+        document field prefixed by `__search_expr.search#`.
+        """
+        replacements = {}
+        annotation_group_idx = itertools.count(start=1)
+        for expr in self.query.annotation_select.values():
+            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
+        for expr, _ in order_by:
+            self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
+        self._prepare_search_expressions_for_pipeline(
+            self.having, annotation_group_idx, replacements
+        )
+        self._prepare_search_expressions_for_pipeline(
+            self.get_where(), annotation_group_idx, replacements
+        )
+        return replacements
+
     def _prepare_annotations_for_aggregation_pipeline(self, order_by):
         """Prepare annotations for the aggregation pipeline."""
         replacements = {}
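The `__search_expr.search#` replacement can also be pictured at the pipeline level. A sketch under assumptions (the index name, paths, and threshold are invented; the real $search body comes from the expression's as_mql()): the score is materialized once as a document field, and annotations, WHERE conditions, and ORDER BY built on top of it reference that field instead of re-running the search.

    # Sketch only: stage contents are illustrative, not produced by this code.
    pipeline = [
        {"$search": {"index": "default", "text": {"path": "headline", "query": "django"}}},
        # Materialize the score under the generated alias...
        {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}},
        # ...so later stages can filter and sort on it as a plain field.
        {"$match": {"__search_expr.search1": {"$gt": 1.0}}},
        {"$sort": {"__search_expr.search1": -1}},
    ]
    print(pipeline)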
@@ -208,9 +256,67 @@ def _build_aggregation_pipeline(self, ids, group):
             pipeline.append({"$unset": "_id"})
         return pipeline
 
+    def _compound_searches_queries(self, search_replacements):
+        """
+        Build a query pipeline from a mapping of search expressions to result
+        columns.
+
+        Currently only a single $search or $vectorSearch expression is
+        supported. Combining multiple search expressions raises ValueError.
+
+        This method will eventually support hybrid search by allowing the
+        combination of $search and $vectorSearch operations.
+        """
+        if not search_replacements:
+            return []
+        if len(search_replacements) > 1:
+            has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
+            has_vector_search = any(
+                isinstance(search, SearchVector) for search in search_replacements
+            )
+            if has_search and has_vector_search:
+                raise ValueError(
+                    "Cannot combine a `$vectorSearch` with a `$search` operator. "
+                    "If you need to combine them, consider restructuring your query logic or "
+                    "running them as separate queries."
+                )
+            if has_vector_search:
+                raise ValueError(
+                    "Cannot combine two `$vectorSearch` operators. "
+                    "If you need to combine them, consider restructuring your query logic or "
+                    "running them as separate queries."
+                )
+            raise ValueError(
+                "Only one $search operation is allowed per query. "
+                f"Received {len(search_replacements)} search expressions. "
+                "To combine multiple search expressions, use either a CompoundExpression for "
+                "fine-grained control or CombinedSearchExpression for simple logical combinations."
+            )
+        pipeline = []
+        for search, result_col in search_replacements.items():
+            score_function = (
+                "vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
+            )
+            pipeline.extend(
+                [
+                    search.as_mql(self, self.connection),
+                    {
+                        "$addFields": {
+                            result_col.as_mql(self, self.connection, as_path=True): {
+                                "$meta": score_function
+                            }
+                        }
+                    },
+                ]
+            )
+        return pipeline
+
     def pre_sql_setup(self, with_col_aliases=False):
         extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
-        group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
+        group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        all_replacements = {**search_replacements, **group_replacements}
+        self.search_pipeline = self._compound_searches_queries(search_replacements)
         # query.group_by is either:
         # - None: no GROUP BY
         # - True: group by select fields
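For a single SearchVector annotation, the fragment returned by _compound_searches_queries is exactly two stages, which pre_sql_setup stores on self.search_pipeline. A sketch with invented index and field names (the real $vectorSearch body is produced by the search expression's as_mql()); adding a second $search or $vectorSearch to the same query raises ValueError, as shown above.

    # Sketch only: index name, path, query vector, and limits are assumptions.
    search_pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embedding",
                "queryVector": [0.1, 0.2, 0.3],
                "numCandidates": 100,
                "limit": 10,
            }
        },
        # The score is exposed under the annotation's column path; as_path=True
        # strips the leading "$" so the path can be used as an $addFields key.
        {"$addFields": {"__search_expr.search1": {"$meta": "vectorSearchScore"}}},
    ]
    print(search_pipeline)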
@@ -235,6 +341,8 @@ def pre_sql_setup(self, with_col_aliases=False):
             for target, expr in self.query.annotation_select.items()
         }
         self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
+        if (where := self.get_where()) and search_replacements:
+            self.set_where(where.replace_expressions(search_replacements))
         return extra_select, order_by, group_by
 
     def execute_sql(
@@ -573,10 +681,16 @@ def get_lookup_pipeline(self):
         return result
 
     def _get_aggregate_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, Aggregate)
+
+    def _get_search_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, SearchExpression)
+
+    def _get_all_expressions_of_type(self, expr, target_type):
         stack = [expr]
         while stack:
             expr = stack.pop()
-            if isinstance(expr, Aggregate):
+            if isinstance(expr, target_type):
                 yield expr
             elif hasattr(expr, "get_source_expressions"):
                 stack.extend(expr.get_source_expressions())
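The traversal itself has nothing Django-specific; the same iterative, stack-based pattern can be shown in isolation (Node and Special below are toy stand-ins for Django expression classes, not part of the backend):

    # Toy illustration of the generic expression-tree traversal used above.
    class Node:
        def __init__(self, *children):
            self.children = children

        def get_source_expressions(self):
            return list(self.children)


    class Special(Node):
        pass


    def get_all_of_type(expr, target_type):
        stack = [expr]
        while stack:
            expr = stack.pop()
            if isinstance(expr, target_type):
                yield expr
            elif hasattr(expr, "get_source_expressions"):
                # Like the original, matched nodes are not descended into.
                stack.extend(expr.get_source_expressions())


    tree = Node(Special(), Node(Special(), Node()))
    print(len(list(get_all_of_type(tree, Special))))  # 2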
@@ -645,6 +759,9 @@ def _get_ordering(self):
     def get_where(self):
         return getattr(self, "where", self.query.where)
 
+    def set_where(self, value):
+        self.where = value
+
     def explain_query(self):
         # Validate format (none supported) and options.
         options = self.connection.ops.explain_query_prefix(
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from .search import (
+    CombinedSearchExpression,
+    CompoundExpression,
+    SearchAutocomplete,
+    SearchEquals,
+    SearchExists,
+    SearchGeoShape,
+    SearchGeoWithin,
+    SearchIn,
+    SearchMoreLikeThis,
+    SearchPhrase,
+    SearchQueryString,
+    SearchRange,
+    SearchRegex,
+    SearchScoreOption,
+    SearchText,
+    SearchVector,
+    SearchWildcard,
+)
+
+__all__ = [
+    "CombinedSearchExpression",
+    "CompoundExpression",
+    "SearchAutocomplete",
+    "SearchEquals",
+    "SearchExists",
+    "SearchGeoShape",
+    "SearchGeoWithin",
+    "SearchIn",
+    "SearchMoreLikeThis",
+    "SearchPhrase",
+    "SearchQueryString",
+    "SearchRange",
+    "SearchRegex",
+    "SearchScoreOption",
+    "SearchText",
+    "SearchVector",
+    "SearchWildcard",
+]

django_mongodb_backend/expressions/builtins.py

Lines changed: 4 additions & 2 deletions
@@ -53,7 +53,7 @@ def case(self, compiler, connection):
     }
 
 
-def col(self, compiler, connection):  # noqa: ARG001
+def col(self, compiler, connection, as_path=False):  # noqa: ARG001
     # If the column is part of a subquery and belongs to one of the parent
     # queries, it will be stored for reference using $let in a $lookup stage.
     # If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +71,9 @@ def col(self, compiler, connection):  # noqa: ARG001
    # Add the column's collection's alias for columns in joined collections.
    has_alias = self.alias and self.alias != compiler.collection_name
    prefix = f"{self.alias}." if has_alias else ""
-    return f"${prefix}{self.target.column}"
+    if not as_path:
+        prefix = f"${prefix}"
+    return f"{prefix}{self.target.column}"
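The new as_path flag exists because aggregation expressions reference a field's value with a leading "$", while other contexts, such as the $addFields key used for the search score above, need the bare dotted path. A small sketch with invented names:

    # Sketch: the same column rendered both ways (field names are invented).
    value_ref = "$author.name"    # as_path=False: reads the field's value
    field_path = "author.name"    # as_path=True: names the field itself

    stage = {"$addFields": {field_path: {"$toUpper": value_ref}}}
    print(stage)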
