Skip to content

Commit a92d9d7

Browse files
committed
First approach.
1 parent fba8892 commit a92d9d7

File tree

3 files changed

+95
-10
lines changed

3 files changed

+95
-10
lines changed

django_mongodb/compiler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from django.db.models.lookups import IsNull
1414
from django.db.models.sql import compiler
1515
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
16+
from django.db.models.sql.datastructures import BaseTable
1617
from django.utils.functional import cached_property
1718
from pymongo import ASCENDING, DESCENDING
1819

@@ -25,12 +26,20 @@ class SQLCompiler(compiler.SQLCompiler):
2526

2627
query_class = MongoQuery
2728
GROUP_SEPARATOR = "___"
29+
PARENT_FIELD_TEMPLATE = "parent__field__{}"
2830

2931
def __init__(self, *args, **kwargs):
3032
super().__init__(*args, **kwargs)
3133
self.aggregation_pipeline = None
3234
# A list of OrderBy objects for this query.
3335
self.order_by_objs = None
36+
# Subquery parent compiler.
37+
self.parent_collections = set()
38+
self.column_mapping = {}
39+
self.subqueries = []
40+
41+
def get_parent(self):
42+
return self.parent_compiler
3443

3544
def _unfold_column(self, col):
3645
"""
@@ -388,6 +397,7 @@ def build_query(self, columns=None):
388397
query.mongo_query = {"$expr": expr}
389398
if extra_fields:
390399
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
400+
query.subqueries = self.subqueries
391401
return query
392402

393403
def get_columns(self):
@@ -431,7 +441,8 @@ def project_field(column):
431441

432442
@cached_property
433443
def collection_name(self):
434-
return self.query.get_meta().db_table
444+
base_table = next(v for v in self.query.alias_map.values() if isinstance(v, BaseTable))
445+
return base_table.table_alias or base_table.table_name
435446

436447
@cached_property
437448
def collection(self):
@@ -581,7 +592,7 @@ def _get_ordering(self):
581592
return tuple(fields), sort_ordering, tuple(extra_fields)
582593

583594
def get_where(self):
584-
return self.where
595+
return getattr(self, "where", self.query.where)
585596

586597
def explain_query(self):
587598
# Validate format (none supported) and options.
@@ -741,7 +752,7 @@ def build_query(self, columns=None):
741752
else None
742753
)
743754
subquery = compiler.build_query(columns)
744-
query.subquery = subquery
755+
query.subqueries = [subquery]
745756
return query
746757

747758
def _make_result(self, result, columns=None):

django_mongodb/expressions.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Case,
1010
Col,
1111
CombinedExpression,
12+
Exists,
1213
ExpressionWrapper,
1314
F,
1415
NegatedExpression,
@@ -50,6 +51,15 @@ def case(self, compiler, connection):
5051

5152

5253
def col(self, compiler, connection): # noqa: ARG001
54+
# If it is a subquery and the columns belongs to one of the ancestors,
55+
# the column shall be stored to be passed using $let in a $lookup stage.
56+
if self.alias in compiler.parent_collections:
57+
try:
58+
index = compiler.column_mapping[self]
59+
except KeyError:
60+
index = len(compiler.column_mapping)
61+
compiler.column_mapping[self] = index
62+
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
5363
# Add the column's collection's alias for columns in joined collections.
5464
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
5565
return f"${prefix}{self.target.column}"
@@ -79,8 +89,64 @@ def order_by(self, compiler, connection):
7989
return self.expression.as_mql(compiler, connection)
8090

8191

82-
def query(self, compiler, connection): # noqa: ARG001
83-
raise NotSupportedError("Using a QuerySet in annotate() is not supported on MongoDB.")
92+
def query(self, compiler, connection):
93+
subquery_compiler = self.get_compiler(connection=connection)
94+
subquery_compiler.pre_sql_setup(with_col_aliases=False)
95+
subquery_compiler.parent_collections = {compiler.collection_name} | compiler.parent_collections
96+
columns = subquery_compiler.get_columns()
97+
field_name, expr = columns[0]
98+
subquery = subquery_compiler.build_query(
99+
columns
100+
if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols
101+
else None
102+
)
103+
table_output = f"__subquery{len(compiler.subqueries)}"
104+
result_query = compiler.query_class(compiler)
105+
pipeline = subquery.get_pipeline()
106+
# the result must be a list of values. Se we compress the output
107+
if not self.has_limit_one():
108+
pipeline.extend(
109+
[
110+
{
111+
"$group": {
112+
"_id": None,
113+
"dummy_name": {"$addToSet": expr.as_mql(subquery_compiler, connection)},
114+
}
115+
},
116+
{"$project": {field_name: "$dummy_name"}},
117+
]
118+
)
119+
result_query.lookup_pipeline = [
120+
{
121+
"$lookup": {
122+
"from": self.get_meta().db_table,
123+
"pipeline": pipeline,
124+
"as": table_output,
125+
"let": {
126+
compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection)
127+
for col, i in subquery_compiler.column_mapping.items()
128+
},
129+
}
130+
},
131+
{
132+
"$set": {
133+
table_output: {
134+
"$cond": {
135+
"if": {
136+
"$or": [
137+
{"$eq": [{"$type": f"${table_output}"}, "missing"]},
138+
{"$eq": [{"$size": f"${table_output}"}, 0]},
139+
]
140+
},
141+
"then": {},
142+
"else": {"$arrayElemAt": [f"${table_output}", 0]},
143+
}
144+
}
145+
}
146+
},
147+
]
148+
compiler.subqueries.append(result_query)
149+
return f"${table_output}.{field_name}"
84150

85151

86152
def raw_sql(self, compiler, connection): # noqa: ARG001
@@ -100,8 +166,13 @@ def star(self, compiler, connection): # noqa: ARG001
100166
return {"$literal": True}
101167

102168

103-
def subquery(self, compiler, connection): # noqa: ARG001
104-
raise NotSupportedError(f"{self.__class__.__name__} is not supported on MongoDB.")
169+
def subquery(self, compiler, connection):
170+
return self.query.as_mql(compiler, connection)
171+
172+
173+
def exists(self, compiler, connection):
174+
lhs_mql = subquery(self, compiler, connection)
175+
return connection.mongo_operators["isnull"](lhs_mql, False)
105176

106177

107178
def when(self, compiler, connection):
@@ -130,6 +201,7 @@ def register_expressions():
130201
Case.as_mql = case
131202
Col.as_mql = col
132203
CombinedExpression.as_mql = combined_expression
204+
Exists.as_mql = exists
133205
ExpressionWrapper.as_mql = expression_wrapper
134206
F.as_mql = f
135207
NegatedExpression.as_mql = negated_expression

django_mongodb/query.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, compiler):
5050
self.collection = self.compiler.collection
5151
self.collection_name = self.compiler.collection_name
5252
self.mongo_query = getattr(compiler.query, "raw_query", {})
53-
self.subquery = None
53+
self.subqueries = None
5454
self.lookup_pipeline = None
5555
self.project_fields = None
5656
self.aggregation_pipeline = compiler.aggregation_pipeline
@@ -74,7 +74,9 @@ def get_cursor(self):
7474
return self.collection.aggregate(self.get_pipeline())
7575

7676
def get_pipeline(self):
77-
pipeline = self.subquery.get_pipeline() if self.subquery else []
77+
pipeline = []
78+
for query in self.subqueries or ():
79+
pipeline.extend(query.get_pipeline())
7880
if self.lookup_pipeline:
7981
pipeline.extend(self.lookup_pipeline)
8082
if self.mongo_query:
@@ -269,7 +271,7 @@ def where_node(self, compiler, connection):
269271
raise FullResultSet
270272

271273
if self.negated and mql:
272-
mql = {"$eq": [mql, {"$literal": False}]}
274+
mql = {"$not": mql}
273275

274276
return mql
275277

0 commit comments

Comments
 (0)