Skip to content

Commit 0f2beae

Browse files
committed
Add support to join extras.
1 parent 27ed098 commit 0f2beae

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

django_mongodb/query.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError, NotSupportedError
6-
from django.db.models.expressions import Case, When
6+
from django.db.models.expressions import Case, Col, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -101,6 +101,7 @@ def join(self, compiler, connection):
101101
lhs_fields = []
102102
rhs_fields = []
103103
# Add a join condition for each pair of joining fields.
104+
parent_template = "parent__field__"
104105
for lhs, rhs in self.join_fields:
105106
lhs, rhs = connection.ops.prepare_join_on_clause(
106107
self.parent_alias, lhs, compiler.collection_name, rhs
@@ -109,8 +110,28 @@ def join(self, compiler, connection):
109110
# In the lookup stage, the reference to this column doesn't include
110111
# the collection name.
111112
rhs_fields.append(rhs.as_mql(compiler, connection))
113+
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
114+
if extra:
115+
columns = []
116+
for expr in extra.leaves():
117+
if hasattr(expr, "lhs") and isinstance(expr.lhs, Col):
118+
columns.append((expr.lhs, len(lhs_fields)))
119+
lhs_fields.append(expr.lhs.as_mql(compiler, connection))
120+
if hasattr(expr, "rhs") and isinstance(expr.rhs, Col):
121+
columns.append((expr.rhs, None))
122+
replacements = {}
123+
for col, parent_pos in columns:
124+
column_target = col.copy()
125+
if column_target.alias == self.table_name:
126+
column_target.alias = compiler.collection_name
127+
else:
128+
column_target.target.db_column = f"{parent_template}{parent_pos}"
129+
column_target.target.set_attributes_from_name(f"{parent_template}{len(lhs_fields)}")
130+
replacements[col] = column_target
131+
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
132+
else:
133+
extra_condition = []
112134

113-
parent_template = "parent__field__"
114135
lookup_pipeline = [
115136
{
116137
"$lookup": {
@@ -136,6 +157,7 @@ def join(self, compiler, connection):
136157
{"$eq": [f"$${parent_template}{i}", {"$toString": field}]}
137158
for i, field in enumerate(rhs_fields)
138159
]
160+
+ extra_condition
139161
}
140162
}
141163
}

0 commit comments

Comments
 (0)