Skip to content

INTPYTHON-522, INTPYTHON-524 Add support for Atlas and vector search queries #325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
args: ["-L", "nin"]
args: ["-L", "nin", "-L", "searchin"]
2 changes: 1 addition & 1 deletion django_mongodb_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
from .expressions import register_expressions # noqa: E402
from .expressions.builtins import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
Expand Down
163 changes: 140 additions & 23 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors
from .query_utils import is_direct_value

Expand All @@ -35,6 +36,8 @@ def __init__(self, *args, **kwargs):
# A list of OrderBy objects for this query.
self.order_by_objs = None
self.subqueries = []
# Atlas search stage.
self.search_pipeline = []

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
Expand All @@ -58,6 +61,29 @@ def _get_column_from_expression(self, expr, alias):
column_target.set_attributes_from_name(alias)
return Col(self.collection_name, column_target)

def _get_replace_expr(self, sub_expr, group, alias):
    """
    Record the MQL for `sub_expr` in `group` and return its stand-in expression.

    `group` is updated in place under the key `alias`. The returned expression is built
    around a column of this compiler's collection named `alias`, whose field is a clone of
    `sub_expr.output_field`.
    """
    alias_field = sub_expr.output_field.clone()
    alias_field.db_column = alias
    alias_field.set_attributes_from_name(alias)
    alias_col = Col(self.collection_name, alias_field)
    if not getattr(sub_expr, "distinct", False):
        group[alias] = sub_expr.as_mql(self, self.connection)
        stand_in = alias_col
    else:
        # Distinct values are collected with $addToSet so duplicates are dropped; the
        # stand-in is then a copy of the expression reading from the collected column.
        group[alias] = {
            "$addToSet": sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        }
        stand_in = sub_expr.copy()
        stand_in.set_source_expressions([alias_col, None])
    if isinstance(sub_expr, Count):
        # Count must return 0 rather than null.
        stand_in = Coalesce(stand_in, 0)
    if isinstance(sub_expr, Variance):
        # Variance = StdDev^2
        stand_in = Power(stand_in, 2)
    return stand_in

def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
"""
Prepare expressions for the aggregation pipeline.
Expand All @@ -81,29 +107,51 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
alias = (
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
)
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if sub_expr.distinct:
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
replacements[sub_expr] = replacing_expr
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
return replacements, group

def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
"""
Collect and prepare unique search expressions for inclusion in an
aggregation pipeline.
Iterate over all search sub-expressions of the given expression.
Assigning a unique alias to each and map them to their replacement
expressions.
"""
searches = {}
for sub_expr in self._get_search_expressions(expression):
if sub_expr not in replacements:
alias = f"__search_expr.search{next(search_idx)}"
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)

def _prepare_search_query_for_aggregation_pipeline(self, order_by):
"""
Prepare expressions for the search pipeline.
Handle the computation of search functions used by various expressions.
Separate and create intermediate columns, and replace nodes to simulate
a search operation.
To apply operations over the $search or $searchVector stages, compute
the $search or $vectorSearch first, then apply additional operations in
a subsequent stage by replacing the aggregate expressions with a new
document field prefixed by `__search_expr.search#`.
"""
replacements = {}
annotation_group_idx = itertools.count(start=1)
for expr in self.query.annotation_select.values():
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
for expr, _ in order_by:
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)
self._prepare_search_expressions_for_pipeline(
self.having, annotation_group_idx, replacements
)
self._prepare_search_expressions_for_pipeline(
self.get_where(), annotation_group_idx, replacements
)
return replacements

def _prepare_annotations_for_aggregation_pipeline(self, order_by):
"""Prepare annotations for the aggregation pipeline."""
replacements = {}
Expand Down Expand Up @@ -208,9 +256,67 @@ def _build_aggregation_pipeline(self, ids, group):
pipeline.append({"$unset": "_id"})
return pipeline

def _compound_searches_queries(self, search_replacements):
"""
Build a query pipeline from a mapping of search expressions to result
columns.
Currently only a single $search or $vectorSearch expression is
supported. Combining multiple search expressions raises ValueError.
This method will eventually support hybrid search by allowing the
combination of $search and $vectorSearch operations.
"""
if not search_replacements:
return []
if len(search_replacements) > 1:
has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
has_vector_search = any(
isinstance(search, SearchVector) for search in search_replacements
)
if has_search and has_vector_search:
raise ValueError(
"Cannot combine a `$vectorSearch` with a `$search` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
if has_vector_search:
raise ValueError(
"Cannot combine two `$vectorSearch` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
raise ValueError(
"Only one $search operation is allowed per query. "
f"Received {len(search_replacements)} search expressions. "
"To combine multiple search expressions, use either a CompoundExpression for "
"fine-grained control or CombinedSearchExpression for simple logical combinations."
)
pipeline = []
for search, result_col in search_replacements.items():
score_function = (
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
)
pipeline.extend(
[
search.as_mql(self, self.connection),
{
"$addFields": {
result_col.as_mql(self, self.connection, as_path=True): {
"$meta": score_function
}
}
},
]
)
return pipeline

def pre_sql_setup(self, with_col_aliases=False):
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
all_replacements = {**search_replacements, **group_replacements}
self.search_pipeline = self._compound_searches_queries(search_replacements)
# query.group_by is either:
# - None: no GROUP BY
# - True: group by select fields
Expand All @@ -235,6 +341,8 @@ def pre_sql_setup(self, with_col_aliases=False):
for target, expr in self.query.annotation_select.items()
}
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
if (where := self.get_where()) and search_replacements:
self.set_where(where.replace_expressions(search_replacements))
return extra_select, order_by, group_by

def execute_sql(
Expand Down Expand Up @@ -573,10 +681,16 @@ def get_lookup_pipeline(self):
return result

def _get_aggregate_expressions(self, expr):
    """Return an iterator over the Aggregate sub-expressions found in `expr`."""
    return self._get_all_expressions_of_type(expr, target_type=Aggregate)

def _get_search_expressions(self, expr):
    """Return an iterator over the SearchExpression sub-expressions found in `expr`."""
    return self._get_all_expressions_of_type(expr, target_type=SearchExpression)

def _get_all_expressions_of_type(self, expr, target_type):
stack = [expr]
while stack:
expr = stack.pop()
if isinstance(expr, Aggregate):
if isinstance(expr, target_type):
yield expr
elif hasattr(expr, "get_source_expressions"):
stack.extend(expr.get_source_expressions())
Expand Down Expand Up @@ -645,6 +759,9 @@ def _get_ordering(self):
def get_where(self):
    """Return this compiler's own `where` if one has been set, else the query's."""
    query_where = self.query.where
    return getattr(self, "where", query_where)

def set_where(self, value):
    """Store `value` as this compiler's `where` attribute."""
    setattr(self, "where", value)

def explain_query(self):
# Validate format (none supported) and options.
options = self.connection.ops.explain_query_prefix(
Expand Down
39 changes: 39 additions & 0 deletions django_mongodb_backend/expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from .search import (
CombinedSearchExpression,
CompoundExpression,
SearchAutocomplete,
SearchEquals,
SearchExists,
SearchGeoShape,
SearchGeoWithin,
SearchIn,
SearchMoreLikeThis,
SearchPhrase,
SearchQueryString,
SearchRange,
SearchRegex,
SearchScoreOption,
SearchText,
SearchVector,
SearchWildcard,
)

# Names exported by this package; each one is imported from `.search` above.
__all__ = [
"CombinedSearchExpression",
"CompoundExpression",
"SearchAutocomplete",
"SearchEquals",
"SearchExists",
"SearchGeoShape",
"SearchGeoWithin",
"SearchIn",
"SearchMoreLikeThis",
"SearchPhrase",
"SearchQueryString",
"SearchRange",
"SearchRegex",
"SearchScoreOption",
"SearchText",
"SearchVector",
"SearchWildcard",
]
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from django.db.models.sql import Query

from .query_utils import process_lhs
from ..query_utils import process_lhs


def case(self, compiler, connection):
Expand Down Expand Up @@ -53,7 +53,7 @@ def case(self, compiler, connection):
}


def col(self, compiler, connection): # noqa: ARG001
def col(self, compiler, connection, as_path=False): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
# If the query is built with `alias_cols=False`, treat the column as
Expand All @@ -71,7 +71,9 @@ def col(self, compiler, connection): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"
if not as_path:
prefix = f"${prefix}"
return f"{prefix}{self.target.column}"


def col_pairs(self, compiler, connection):
Expand Down
Loading