Skip to content

Commit 1ca96d9

Browse files
committed
Cast join condition when the types are differents.
1 parent 7f53f46 commit 1ca96d9

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

django_mongodb/operations.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from django.conf import settings
99
from django.db import DataError
1010
from django.db.backends.base.operations import BaseDatabaseOperations
11+
from django.db.models import TextField
1112
from django.db.models.expressions import Combinable
13+
from django.db.models.functions import Cast
1214
from django.utils import timezone
1315
from django.utils.regex_helper import _lazy_re_compile
1416

@@ -254,3 +256,17 @@ def format_for_duration_arithmetic(self, sql):
254256

255257
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
256258
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
259+
260+
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
261+
lhs_expr, rhs_expr = super().prepare_join_on_clause(
262+
lhs_table, lhs_field, rhs_table, rhs_field
263+
)
264+
265+
# If the types are different, we cast both to string.
266+
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
267+
if lhs_field.db_type(self.connection) != "string":
268+
lhs_expr = Cast(lhs_expr, output_field=TextField())
269+
if rhs_field.db_type(self.connection) != "string":
270+
rhs_expr = Cast(rhs_expr, output_field=TextField())
271+
272+
return lhs_expr, rhs_expr

django_mongodb/query.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ def join(self, compiler, connection):
9999
# Add a join condition for each pair of joining fields.
100100
for lhs, rhs in self.join_fields:
101101
lhs, rhs = connection.ops.prepare_join_on_clause(
102-
self.parent_alias, lhs, self.table_name, rhs
102+
self.parent_alias, lhs, compiler.collection_name, rhs
103103
)
104104
lhs_fields.append(lhs.as_mql(compiler, connection))
105105
# In the lookup stage, the reference to this column doesn't include
106106
# the collection name.
107-
rhs_fields.append(rhs.as_mql(compiler, connection).replace(f"{self.table_name}.", "", 1))
107+
rhs_fields.append(rhs.as_mql(compiler, connection))
108108

109109
parent_template = "parent__field__"
110110
lookup_pipeline = [
@@ -115,7 +115,7 @@ def join(self, compiler, connection):
115115
# The pipeline variables to be matched in the pipeline's
116116
# expression.
117117
"let": {
118-
f"{parent_template}{i}": {"$toString": parent_field}
118+
f"{parent_template}{i}": parent_field
119119
for i, parent_field in enumerate(lhs_fields)
120120
},
121121
"pipeline": [
@@ -129,7 +129,7 @@ def join(self, compiler, connection):
129129
"$match": {
130130
"$expr": {
131131
"$and": [
132-
{"$eq": [f"$${parent_template}{i}", {"$toString": field}]}
132+
{"$eq": [f"$${parent_template}{i}", field]}
133133
for i, field in enumerate(rhs_fields)
134134
]
135135
}

0 commit comments

Comments
 (0)