Skip to content

Commit 8b98b58

Browse files
WaVEVtimgraham
authored andcommitted
fix GenericRelation object_id / target id join type mismatch
1 parent d0c102e commit 8b98b58

File tree

3 files changed

+16
-22
lines changed

3 files changed

+16
-22
lines changed

django_mongodb/features.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,26 +87,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
8787
"backends.tests.ThreadTests.test_pass_connection_between_threads",
8888
"backends.tests.ThreadTests.test_closing_non_shared_connections",
8989
"backends.tests.ThreadTests.test_default_connection_thread_local",
90-
# GenericRelation join doesn't work due to type mismatch between
91-
# object_id (string) and target id (ObjectId) field.
92-
"generic_relations.tests.GenericRelationsTests.test_subclasses_with_gen_rel",
93-
"generic_relations.tests.GenericRelationsTests.test_subclasses_with_parent_gen_rel",
94-
"generic_relations.tests.ProxyRelatedModelTest.test_query",
95-
"generic_relations.tests.ProxyRelatedModelTest.test_query_proxy",
96-
"generic_relations.tests.GenericRelationsTests.test_access_via_content_type",
97-
"generic_relations.tests.GenericRelationsTests.test_generic_relation_to_inherited_child",
98-
"generic_relations.tests.GenericRelationsTests.test_query_content_object",
99-
"generic_relations_regress.tests.GenericRelationTests.test_filter_on_related_proxy_model",
100-
"generic_relations_regress.tests.GenericRelationTests.test_charlink_filter",
101-
"generic_relations_regress.tests.GenericRelationTests.test_filter_targets_related_pk",
102-
"generic_relations_regress.tests.GenericRelationTests.test_generic_reverse_relation_exclude_filter",
103-
"generic_relations_regress.tests.GenericRelationTests.test_generic_reverse_relation_with_abc",
104-
"generic_relations_regress.tests.GenericRelationTests.test_generic_reverse_relation_with_mti",
105-
"generic_relations_regress.tests.GenericRelationTests.test_reverse_relation_pk",
106-
"generic_relations_regress.tests.GenericRelationTests.test_textlink_filter",
107-
"generic_relations_regress.tests.GenericRelationTests.test_ticket_20378",
108-
"generic_relations_regress.tests.GenericRelationTests.test_ticket_20564",
109-
"generic_relations_regress.tests.GenericRelationTests.test_ticket_20564_nullable_fk",
11090
# AddField
11191
"schema.tests.SchemaTests.test_add_indexed_charfield",
11292
"schema.tests.SchemaTests.test_add_unique_charfield",

django_mongodb/operations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from django.conf import settings
99
from django.db import DataError
1010
from django.db.backends.base.operations import BaseDatabaseOperations
11+
from django.db.models import TextField
1112
from django.db.models.expressions import Combinable
13+
from django.db.models.functions import Cast
1214
from django.utils import timezone
1315
from django.utils.regex_helper import _lazy_re_compile
1416

@@ -212,6 +214,18 @@ def explain_query_prefix(self, format=None, **options):
212214
super().explain_query_prefix(format, **options)
213215
return validated_options
214216

217+
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
218+
lhs_expr, rhs_expr = super().prepare_join_on_clause(
219+
lhs_table, lhs_field, rhs_table, rhs_field
220+
)
221+
# If the types are different, cast both to string.
222+
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
223+
if lhs_field.db_type(self.connection) != "string":
224+
lhs_expr = Cast(lhs_expr, output_field=TextField())
225+
if rhs_field.db_type(self.connection) != "string":
226+
rhs_expr = Cast(rhs_expr, output_field=TextField())
227+
return lhs_expr, rhs_expr
228+
215229
"""Django uses these methods to generate SQL queries before it generates MQL queries."""
216230

217231
# EXTRACT format cannot be passed in parameters.

django_mongodb/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ def join(self, compiler, connection):
9999
# Add a join condition for each pair of joining fields.
100100
for lhs, rhs in self.join_fields:
101101
lhs, rhs = connection.ops.prepare_join_on_clause(
102-
self.parent_alias, lhs, self.table_name, rhs
102+
self.parent_alias, lhs, compiler.collection_name, rhs
103103
)
104104
lhs_fields.append(lhs.as_mql(compiler, connection))
105105
# In the lookup stage, the reference to this column doesn't include
106106
# the collection name.
107-
rhs_fields.append(rhs.as_mql(compiler, connection).replace(f"{self.table_name}.", "", 1))
107+
rhs_fields.append(rhs.as_mql(compiler, connection))
108108

109109
parent_template = "parent__field__"
110110
lookup_pipeline = [

0 commit comments

Comments
 (0)