diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 1aa64ae42..a870b0e8d 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -83,6 +83,8 @@ jobs: defer defer_regress from_db_value + generic_relations + generic_relations_regress introspection known_related_objects lookup diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 2c3a00a98..72e8b2706 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -79,6 +79,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "many_to_one.tests.ManyToOneTests.test_selects", # Incorrect JOIN with GenericRelation gives incorrect results. "aggregation_regress.tests.AggregationTests.test_aggregation_with_generic_reverse_relation", + "generic_relations.tests.GenericRelationsTests.test_queries_content_type_restriction", + "generic_relations_regress.tests.GenericRelationTests.test_annotate", # subclasses of BaseDatabaseWrapper may require an is_usable() method "backends.tests.BackendTestCase.test_is_usable_after_database_disconnects", # Connection creation doesn't follow the usual Django API. diff --git a/django_mongodb/operations.py b/django_mongodb/operations.py index bf5ed8bb8..04ba5fcbf 100644 --- a/django_mongodb/operations.py +++ b/django_mongodb/operations.py @@ -8,7 +8,9 @@ from django.conf import settings from django.db import DataError from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.models import TextField from django.db.models.expressions import Combinable +from django.db.models.functions import Cast from django.utils import timezone from django.utils.regex_helper import _lazy_re_compile @@ -212,6 +214,18 @@ def explain_query_prefix(self, format=None, **options): super().explain_query_prefix(format, **options) return validated_options + def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): + lhs_expr, rhs_expr = super().prepare_join_on_clause( + lhs_table, lhs_field, rhs_table, rhs_field + ) + # If the types are different, cast both to string. + if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection): + if lhs_field.db_type(self.connection) != "string": + lhs_expr = Cast(lhs_expr, output_field=TextField()) + if rhs_field.db_type(self.connection) != "string": + rhs_expr = Cast(rhs_expr, output_field=TextField()) + return lhs_expr, rhs_expr + """Django uses these methods to generate SQL queries before it generates MQL queries.""" # EXTRACT format cannot be passed in parameters. diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 7a7848cd0..25d857bc1 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -99,12 +99,12 @@ def join(self, compiler, connection): # Add a join condition for each pair of joining fields. for lhs, rhs in self.join_fields: lhs, rhs = connection.ops.prepare_join_on_clause( - self.parent_alias, lhs, self.table_name, rhs + self.parent_alias, lhs, compiler.collection_name, rhs ) lhs_fields.append(lhs.as_mql(compiler, connection)) # In the lookup stage, the reference to this column doesn't include # the collection name. - rhs_fields.append(rhs.as_mql(compiler, connection).replace(f"{self.table_name}.", "", 1)) + rhs_fields.append(rhs.as_mql(compiler, connection)) parent_template = "parent__field__" lookup_pipeline = [