Skip to content

Commit 390da39

Browse files
committed
Convert conditions into a whereNode
1 parent 9731054 commit 390da39

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

django_mongodb_backend/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16-
from django.db.models.sql.where import AND
16+
from django.db.models.sql.where import AND, WhereNode
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

@@ -556,15 +556,15 @@ def get_lookup_pipeline(self):
556556
if (
557557
isinstance(expr, Lookup)
558558
and isinstance(expr.lhs, Col)
559-
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value))
559+
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value | Col))
560560
):
561561
promote_filters[expr.lhs.alias].append(expr)
562562

563563
for alias in tuple(self.query.alias_map):
564564
if not self.query.alias_refcount[alias] or self.collection_name == alias:
565565
continue
566566
result += self.query.alias_map[alias].as_mql(
567-
self, self.connection, promote_filters[alias]
567+
self, self.connection, WhereNode(promote_filters[alias], connector=AND)
568568
)
569569
return result
570570

django_mongodb_backend/query.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,13 @@ def extra_where(self, compiler, connection): # noqa: ARG001
124124

125125

126126
def join(self, compiler, connection, pushed_expressions=None):
127+
parent_template = "parent__field__"
128+
127129
def _get_reroot_replacements(expressions):
128130
if not expressions:
129131
return None
130132
columns = []
131-
for expr in expressions:
133+
for expr in expressions.leaves():
132134
# Determine whether the column needs to be transformed or rerouted
133135
# as part of the subquery.
134136
for hand_side in ["lhs", "rhs"]:
@@ -162,7 +164,6 @@ def _get_reroot_replacements(expressions):
162164
lhs_fields = []
163165
rhs_fields = []
164166
# Add a join condition for each pair of joining fields.
165-
parent_template = "parent__field__"
166167
for lhs, rhs in self.join_fields:
167168
lhs, rhs = connection.ops.prepare_join_on_clause(
168169
self.parent_alias, lhs, compiler.collection_name, rhs
@@ -174,19 +175,20 @@ def _get_reroot_replacements(expressions):
174175
# Handle any join conditions besides matching field pairs.
175176
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
176177

178+
extra_conditions = []
177179
if extra:
178-
replacements = _get_reroot_replacements(extra.leaves())
179-
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
180-
else:
181-
extra_condition = []
182-
if self.join_type == INNER:
180+
replacements = _get_reroot_replacements(extra)
181+
extra_conditions.append(
182+
extra.replace_expressions(replacements).as_mql(compiler, connection)
183+
)
184+
185+
if pushed_expressions and self.join_type == INNER:
183186
rerooted_replacement = _get_reroot_replacements(pushed_expressions)
184-
resolved_pushed_expressions = [
185-
expr.replace_expressions(rerooted_replacement).as_mql(compiler, connection)
186-
for expr in pushed_expressions
187-
]
188-
else:
189-
resolved_pushed_expressions = []
187+
extra_conditions.append(
188+
pushed_expressions.replace_expressions(rerooted_replacement).as_mql(
189+
compiler, connection
190+
)
191+
)
190192

191193
lookup_pipeline = [
192194
{
@@ -213,8 +215,7 @@ def _get_reroot_replacements(expressions):
213215
{"$eq": [f"$${parent_template}{i}", field]}
214216
for i, field in enumerate(rhs_fields)
215217
]
216-
+ extra_condition
217-
+ resolved_pushed_expressions
218+
+ extra_conditions
218219
}
219220
}
220221
}

0 commit comments

Comments
 (0)