Skip to content

add support for subqueries (Subquery, Exists, and QuerySet as a lookup value) #149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,6 @@ Congratulations, your project is ready to go!
- `QuerySet.delete()` and `update()` do not support queries that span multiple
collections.

- `Subquery`, `Exists`, and using a `QuerySet` in `QuerySet.annotate()` aren't
supported.

- `DateTimeField` doesn't support microsecond precision, and correspondingly,
`DurationField` stores milliseconds rather than microseconds.

Expand Down
76 changes: 53 additions & 23 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from django.db.models.lookups import IsNull
from django.db.models.sql import compiler
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
from django.db.models.sql.datastructures import BaseTable
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

Expand All @@ -25,12 +26,16 @@ class SQLCompiler(compiler.SQLCompiler):

query_class = MongoQuery
GROUP_SEPARATOR = "___"
PARENT_FIELD_TEMPLATE = "parent__field__{}"

def __init__(self, *args, **kwargs):
    """Initialize the compiler and its per-query MongoDB state."""
    super().__init__(*args, **kwargs)
    # The aggregation pipeline for this query; None until one is built.
    self.aggregation_pipeline = None
    # Map columns to their subquery indices.
    self.column_indices = {}
    # A list of OrderBy objects for this query.
    self.order_by_objs = None
    # Compiled subqueries collected while this query's expressions are
    # rendered.
    self.subqueries = []

def _unfold_column(self, col):
"""
Expand Down Expand Up @@ -154,23 +159,40 @@ def _prepare_annotations_for_aggregation_pipeline(self, order_by):
group.update(having_group)
return group, replacements

def _get_group_id_expressions(self, order_by):
"""Generate group ID expressions for the aggregation pipeline."""
group_expressions = set()
replacements = {}
def _get_group_expressions(self, order_by):
if self.query.group_by is None:
return []
seen = set()
expressions = set()
if self.query.group_by is not True:
# If group_by isn't True, then it's a list of expressions.
for expr in self.query.group_by:
if not hasattr(expr, "as_sql"):
expr = self.query.resolve_ref(expr)
if isinstance(expr, Ref):
if expr.refs not in seen:
seen.add(expr.refs)
expressions.add(expr.source)
else:
expressions.add(expr)
for expr, _, alias in self.select:
# Skip members that are already grouped.
if alias not in seen:
expressions |= set(expr.get_group_by_cols())
if not self._meta_ordering:
for expr, (_, _, is_ref) in order_by:
# Skip references.
if not is_ref:
group_expressions |= set(expr.get_group_by_cols())
for expr, *_ in self.select:
group_expressions |= set(expr.get_group_by_cols())
expressions |= set(expr.get_group_by_cols())
having_group_by = self.having.get_group_by_cols() if self.having else ()
for expr in having_group_by:
group_expressions.add(expr)
if isinstance(self.query.group_by, tuple | list):
group_expressions |= set(self.query.group_by)
elif self.query.group_by is None:
group_expressions = set()
expressions.add(expr)
return expressions

def _get_group_id_expressions(self, order_by):
"""Generate group ID expressions for the aggregation pipeline."""
replacements = {}
group_expressions = self._get_group_expressions(order_by)
if not group_expressions:
ids = None
else:
Expand All @@ -186,6 +208,8 @@ def _get_group_id_expressions(self, order_by):
ids[alias] = Value(True).as_mql(self, self.connection)
if replacement is not None:
replacements[col] = replacement
if isinstance(col, Ref):
replacements[col.source] = replacement
return ids, replacements

def _build_aggregation_pipeline(self, ids, group):
Expand Down Expand Up @@ -228,15 +252,15 @@ def pre_sql_setup(self, with_col_aliases=False):
all_replacements.update(replacements)
pipeline = self._build_aggregation_pipeline(ids, group)
if self.having:
pipeline.append(
{
"$match": {
"$expr": self.having.replace_expressions(all_replacements).as_mql(
self, self.connection
)
}
}
having = self.having.replace_expressions(all_replacements).as_mql(
self, self.connection
)
# Add HAVING subqueries.
for query in self.subqueries or ():
pipeline.extend(query.get_pipeline())
# Remove the added subqueries.
self.subqueries = []
pipeline.append({"$match": {"$expr": having}})
self.aggregation_pipeline = pipeline
self.annotations = {
target: expr.replace_expressions(all_replacements)
Expand Down Expand Up @@ -388,6 +412,7 @@ def build_query(self, columns=None):
query.mongo_query = {"$expr": expr}
if extra_fields:
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
query.subqueries = self.subqueries
return query

def get_columns(self):
Expand Down Expand Up @@ -431,7 +456,12 @@ def project_field(column):

@cached_property
def collection_name(self):
return self.query.get_meta().db_table
base_table = next(
v
for k, v in self.query.alias_map.items()
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
)
return base_table.table_alias or base_table.table_name

@cached_property
def collection(self):
Expand Down Expand Up @@ -581,7 +611,7 @@ def _get_ordering(self):
return tuple(fields), sort_ordering, tuple(extra_fields)

def get_where(self):
return self.where
return getattr(self, "where", self.query.where)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we are back to getattr() for this... maybe we should drop the get_where() hook we previously added to try to avoid it. That could be done in a follow-up, though. For now, perhaps a comment is warranted.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay.


def explain_query(self):
# Validate format (none supported) and options.
Expand Down Expand Up @@ -741,7 +771,7 @@ def build_query(self, columns=None):
else None
)
subquery = compiler.build_query(columns)
query.subquery = subquery
query.subqueries = [subquery]
return query

def _make_result(self, result, columns=None):
Expand Down
95 changes: 91 additions & 4 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Case,
Col,
CombinedExpression,
Exists,
ExpressionWrapper,
F,
NegatedExpression,
Expand Down Expand Up @@ -50,6 +51,18 @@ def case(self, compiler, connection):


def col(self, compiler, connection):  # noqa: ARG001
    """
    Return the MQL reference for this column.

    A column whose alias is unknown to, or unused by, the compiler's query
    is rendered as a ``$$`` variable built from
    ``compiler.PARENT_FIELD_TEMPLATE`` and a stable per-compiler index.
    Any other column is rendered as a ``$`` field path, prefixed with its
    alias unless the alias is the compiler's own collection name.
    """
    # If the column is part of a subquery and belongs to one of the parent
    # queries, it will be stored for reference using $let in a $lookup stage.
    if not compiler.query.alias_refcount.get(self.alias):
        # Reuse the index already assigned to this column; otherwise
        # allocate the next one. len() is evaluated before the insertion.
        index = compiler.column_indices.setdefault(self, len(compiler.column_indices))
        return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
    # Add the column's collection's alias for columns in joined collections.
    prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
    return f"${prefix}{self.target.column}"
Expand Down Expand Up @@ -79,8 +92,73 @@ def order_by(self, compiler, connection):
return self.expression.as_mql(compiler, connection)


def query(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("Using a QuerySet in annotate() is not supported on MongoDB.")
def query(self, compiler, connection, lookup_name=None):
    """
    Compile this query as a subquery of ``compiler``'s query.

    The compiled subquery is appended to ``compiler.subqueries`` and the
    returned string is the ``$`` path to the first selected column inside
    the subquery's output field.

    When ``lookup_name`` is ``"in"`` or ``"range"``, extra pipeline stages
    collapse the subquery's rows into a single list of values.
    """
    subquery_compiler = self.get_compiler(connection=connection)
    subquery_compiler.pre_sql_setup(with_col_aliases=False)
    columns = subquery_compiler.get_columns()
    # Only the first selected column is referenced by the outer query.
    field_name, expr = columns[0]
    subquery = subquery_compiler.build_query(
        columns
        if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols
        else None
    )
    # Each subquery of the outer compiler gets its own numbered output field.
    table_output = f"__subquery{len(compiler.subqueries)}"
    from_table = next(
        e.table_name for alias, e in self.alias_map.items() if self.alias_refcount[alias]
    )
    # To perform a subquery, a $lookup stage that encapsulates the entire
    # subquery pipeline is added. The "let" clause defines the variables
    # needed to bridge the main collection with the subquery.
    subquery.subquery_lookup = {
        "as": table_output,
        "from": from_table,
        "let": {
            # Columns recorded by the subquery compiler are rendered with
            # the outer compiler.
            compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection)
            for col, i in subquery_compiler.column_indices.items()
        },
    }
    # The result must be a list of values. The output is compressed with an
    # aggregation pipeline.
    if lookup_name in ("in", "range"):
        if subquery.aggregation_pipeline is None:
            subquery.aggregation_pipeline = []
        subquery.aggregation_pipeline.extend(
            [
                {
                    "$facet": {
                        "group": [
                            {
                                "$group": {
                                    "_id": None,
                                    "tmp_name": {
                                        "$addToSet": expr.as_mql(subquery_compiler, connection)
                                    },
                                }
                            }
                        ]
                    }
                },
                {
                    "$project": {
                        field_name: {
                            # Fall back to an empty list when the group
                            # stage produced no document.
                            "$ifNull": [
                                {
                                    "$getField": {
                                        "input": {"$arrayElemAt": ["$group", 0]},
                                        "field": "tmp_name",
                                    }
                                },
                                [],
                            ]
                        }
                    }
                },
            ]
        )
        # Erase project_fields since the required value is projected above.
        subquery.project_fields = None
    compiler.subqueries.append(subquery)
    return f"${table_output}.{field_name}"


def raw_sql(self, compiler, connection): # noqa: ARG001
Expand All @@ -100,8 +178,16 @@ def star(self, compiler, connection): # noqa: ARG001
return {"$literal": True}


def subquery(self, compiler, connection): # noqa: ARG001
raise NotSupportedError(f"{self.__class__.__name__} is not supported on MongoDB.")
def subquery(self, compiler, connection, lookup_name=None):
    """
    Return the MQL for this expression by delegating to its wrapped query.

    ``lookup_name`` is forwarded unchanged to ``self.query.as_mql()``.
    """
    return self.query.as_mql(compiler, connection, lookup_name=lookup_name)


def exists(self, compiler, connection, lookup_name=None):
    """
    Return the MQL for an ``Exists`` expression.

    The wrapped subquery is compiled first. If that raises
    ``EmptyResultSet``, the expression compiles to the MQL of
    ``Value(False)``. Otherwise the connection's ``"isnull"`` operator is
    applied to the subquery reference with ``False``.
    """
    try:
        lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name)
    except EmptyResultSet:
        # The subquery can't match anything, so nothing exists.
        return Value(False).as_mql(compiler, connection)
    # NOTE(review): presumably isnull(..., False) yields "is not null",
    # i.e. the subquery produced a value. Confirm against mongo_operators.
    return connection.mongo_operators["isnull"](lhs_mql, False)


def when(self, compiler, connection):
Expand Down Expand Up @@ -130,6 +216,7 @@ def register_expressions():
Case.as_mql = case
Col.as_mql = col
CombinedExpression.as_mql = combined_expression
Exists.as_mql = exists
ExpressionWrapper.as_mql = expression_wrapper
F.as_mql = f
NegatedExpression.as_mql = negated_expression
Expand Down
Loading