Skip to content

Commit 78a8619

Browse files
WaVEVtimgraham
authored andcommitted
fix incorrect GenericRelation joining
By adding support for Field.get_extra_restriction().
1 parent 5b92640 commit 78a8619

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

django_mongodb/features.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
6868
"aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate",
6969
"aggregation_regress.tests.AggregationTests.test_annotation_disjunction",
7070
"aggregation_regress.tests.AggregationTests.test_decimal_aggregate_annotation_filter",
71-
# Incorrect JOIN with GenericRelation gives incorrect results.
72-
"aggregation_regress.tests.AggregationTests.test_aggregation_with_generic_reverse_relation",
73-
"generic_relations.tests.GenericRelationsTests.test_queries_content_type_restriction",
71+
# Wrong result for GenericRelation annotation.
7472
"generic_relations_regress.tests.GenericRelationTests.test_annotate",
7573
# subclasses of BaseDatabaseWrapper may require an is_usable() method
7674
"backends.tests.BackendTestCase.test_is_usable_after_database_disconnects",

django_mongodb/query.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError, NotSupportedError
6-
from django.db.models.expressions import Case, When
6+
from django.db.models.expressions import Case, Col, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -105,6 +105,7 @@ def join(self, compiler, connection):
105105
lhs_fields = []
106106
rhs_fields = []
107107
# Add a join condition for each pair of joining fields.
108+
parent_template = "parent__field__"
108109
for lhs, rhs in self.join_fields:
109110
lhs, rhs = connection.ops.prepare_join_on_clause(
110111
self.parent_alias, lhs, compiler.collection_name, rhs
@@ -113,8 +114,41 @@ def join(self, compiler, connection):
113114
# In the lookup stage, the reference to this column doesn't include
114115
# the collection name.
115116
rhs_fields.append(rhs.as_mql(compiler, connection))
117+
# Handle any join conditions besides matching field pairs.
118+
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
119+
if extra:
120+
columns = []
121+
for expr in extra.leaves():
122+
# Determine whether the column needs to be transformed or rerouted
123+
# as part of the subquery.
124+
for hand_side in ["lhs", "rhs"]:
125+
hand_side_value = getattr(expr, hand_side, None)
126+
if isinstance(hand_side_value, Col):
127+
# If the column is not part of the joined table, add it to
128+
# lhs_fields.
129+
if hand_side_value.alias != self.table_name:
130+
pos = len(lhs_fields)
131+
lhs_fields.append(expr.lhs.as_mql(compiler, connection))
132+
else:
133+
pos = None
134+
columns.append((hand_side_value, pos))
135+
# Replace columns in the extra conditions with new column references
136+
# based on their rerouted positions in the join pipeline.
137+
replacements = {}
138+
for col, parent_pos in columns:
139+
column_target = Col(compiler.collection_name, expr.output_field.__class__())
140+
if parent_pos is not None:
141+
target_col = f"${parent_template}{parent_pos}"
142+
column_target.target.db_column = target_col
143+
column_target.target.set_attributes_from_name(target_col)
144+
else:
145+
column_target.target = col.target
146+
replacements[col] = column_target
147+
# Apply the transformed expressions in the extra condition.
148+
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
149+
else:
150+
extra_condition = []
116151

117-
parent_template = "parent__field__"
118152
lookup_pipeline = [
119153
{
120154
"$lookup": {
@@ -140,6 +174,7 @@ def join(self, compiler, connection):
140174
{"$eq": [f"$${parent_template}{i}", field]}
141175
for i, field in enumerate(rhs_fields)
142176
]
177+
+ extra_condition
143178
}
144179
}
145180
}

0 commit comments

Comments
 (0)