Skip to content

Commit e1df829

Browse files
committed
Add support to join extras.
1 parent 0935f03 commit e1df829

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

django_mongodb/query.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError
6-
from django.db.models.expressions import Case, When
6+
from django.db.models.expressions import Case, Col, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -97,6 +97,7 @@ def join(self, compiler, connection):
9797
lhs_fields = []
9898
rhs_fields = []
9999
# Add a join condition for each pair of joining fields.
100+
parent_template = "parent__field__"
100101
for lhs, rhs in self.join_fields:
101102
lhs, rhs = connection.ops.prepare_join_on_clause(
102103
self.parent_alias, lhs, self.table_name, rhs
@@ -105,8 +106,28 @@ def join(self, compiler, connection):
105106
# In the lookup stage, the reference to this column doesn't include
106107
# the collection name.
107108
rhs_fields.append(rhs.as_mql(compiler, connection).replace(f"{self.table_name}.", "", 1))
109+
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
110+
if extra:
111+
columns = []
112+
for expr in extra.leaves():
113+
if hasattr(expr, "lhs") and isinstance(expr.lhs, Col):
114+
columns.append((expr.lhs, len(lhs_fields)))
115+
lhs_fields.append(expr.lhs.as_mql(compiler, connection))
116+
if hasattr(expr, "rhs") and isinstance(expr.rhs, Col):
117+
columns.append((expr.rhs, None))
118+
replacements = {}
119+
for col, parent_pos in columns:
120+
column_target = col.copy()
121+
if column_target.alias == self.table_name:
122+
column_target.alias = compiler.collection_name
123+
else:
124+
column_target.target.db_column = f"{parent_template}{parent_pos}"
125+
column_target.target.set_attributes_from_name(f"{parent_template}{len(lhs_fields)}")
126+
replacements[col] = column_target
127+
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
128+
else:
129+
extra_condition = []
108130

109-
parent_template = "parent__field__"
110131
lookup_pipeline = [
111132
{
112133
"$lookup": {
@@ -132,6 +153,7 @@ def join(self, compiler, connection):
132153
{"$eq": [f"$${parent_template}{i}", field]}
133154
for i, field in enumerate(rhs_fields)
134155
]
156+
+ extra_condition
135157
}
136158
}
137159
}

0 commit comments

Comments
 (0)