Skip to content

Commit adc46aa

Browse files
committed
Edits.
1 parent a556c35 commit adc46aa

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

django_mongodb/expressions.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ def case(self, compiler, connection):
5353
def col(self, compiler, connection): # noqa: ARG001
5454
# If it is a subquery and the columns belongs to one of the ancestors,
5555
# the column shall be stored to be passed using $let in a $lookup stage.
56-
if self.alias in compiler.parent_collections:
56+
if (
57+
self.alias not in compiler.query.alias_refcount
58+
or compiler.query.alias_refcount[self.alias] == 0
59+
):
5760
try:
5861
index = compiler.column_mapping[self]
5962
except KeyError:
@@ -89,10 +92,9 @@ def order_by(self, compiler, connection):
8992
return self.expression.as_mql(compiler, connection)
9093

9194

92-
def query(self, compiler, connection):
95+
def query(self, compiler, connection, lookup_name=None):
9396
subquery_compiler = self.get_compiler(connection=connection)
9497
subquery_compiler.pre_sql_setup(with_col_aliases=False)
95-
subquery_compiler.parent_collections = {compiler.collection_name} | compiler.parent_collections
9698
columns = subquery_compiler.get_columns()
9799
field_name, expr = columns[0]
98100
subquery = subquery_compiler.build_query(
@@ -112,8 +114,8 @@ def query(self, compiler, connection):
112114
for col, i in subquery_compiler.column_mapping.items()
113115
},
114116
}
115-
# the result must be a list of values. Se we compress the output with an aggregation pipeline.
116-
if not self.has_limit_one():
117+
# The result must be a list of values. Se we compress the output with an aggregation pipeline.
118+
if lookup_name in ("in", "range"):
117119
subquery.aggregation_pipeline = [
118120
{
119121
"$group": {
@@ -144,8 +146,8 @@ def star(self, compiler, connection): # noqa: ARG001
144146
return {"$literal": True}
145147

146148

147-
def subquery(self, compiler, connection):
148-
return self.query.as_mql(compiler, connection)
149+
def subquery(self, compiler, connection, lookup_name=None):
150+
return self.query.as_mql(compiler, connection, lookup_name=lookup_name)
149151

150152

151153
def exists(self, compiler, connection):

django_mongodb/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def get_cursor(self):
7676

7777
def get_pipeline(self):
7878
pipeline = []
79-
for query in self.subqueries or ():
80-
pipeline.extend(query.get_pipeline())
8179
if self.lookup_pipeline:
8280
pipeline.extend(self.lookup_pipeline)
81+
for query in self.subqueries or ():
82+
pipeline.extend(query.get_pipeline())
8383
if self.mongo_query:
8484
pipeline.append({"$match": self.mongo_query})
8585
if self.aggregation_pipeline:

django_mongodb/query_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ def process_lhs(node, compiler, connection):
2828
def process_rhs(node, compiler, connection):
2929
rhs = node.rhs
3030
if hasattr(rhs, "as_mql"):
31-
value = rhs.as_mql(compiler, connection)
31+
if getattr(rhs, "subquery", False):
32+
value = rhs.as_mql(compiler, connection, lookup_name=node.lookup_name)
33+
else:
34+
value = rhs.as_mql(compiler, connection)
3235
else:
3336
_, value = node.process_rhs(compiler, connection)
34-
lookup_name = node.lookup_name
3537
# Undo Lookup.get_db_prep_lookup() putting params in a list.
36-
if lookup_name not in ("in", "range"):
38+
if node.lookup_name not in ("in", "range"):
3739
value = value[0]
3840
if hasattr(node, "prep_lookup_value_mongo"):
3941
value = node.prep_lookup_value_mongo(value)

0 commit comments

Comments
 (0)