diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 3bb1324ae..6d3d35800 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -1,4 +1,7 @@ +import difflib + from django.core import checks +from django.core.exceptions import FieldDoesNotExist from django.db import models from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform @@ -123,7 +126,8 @@ def get_transform(self, name): transform = super().get_transform(name) if transform: return transform - return KeyTransformFactory(name) + field = self.embedded_model._meta.get_field(name) + return KeyTransformFactory(name, field) def validate(self, value, model_instance): super().validate(value, model_instance) @@ -145,9 +149,36 @@ def formfield(self, **kwargs): class KeyTransform(Transform): - def __init__(self, key_name, *args, **kwargs): + def __init__(self, key_name, ref_field, *args, **kwargs): super().__init__(*args, **kwargs) self.key_name = str(key_name) + self.ref_field = ref_field + + def get_transform(self, name): + """ + Validate that `name` is either a field of an embedded model or a + lookup on an embedded model's field. + """ + result = None + if isinstance(self.ref_field, EmbeddedModelField): + opts = self.ref_field.embedded_model._meta + new_field = opts.get_field(name) + result = KeyTransformFactory(name, new_field) + else: + if self.ref_field.get_transform(name) is None: + suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups()) + if suggested_lookups: + suggested_lookups = " or ".join(suggested_lookups) + suggestion = f", perhaps you meant {suggested_lookups}?" + else: + suggestion = "." + raise FieldDoesNotExist( + f"Unsupported lookup '{name}' for " + f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'" + f"{suggestion}" + ) + result = KeyTransformFactory(name, self.ref_field) + return result def preprocess_lhs(self, compiler, connection): key_transforms = [self.key_name] @@ -165,8 +196,9 @@ def as_mql(self, compiler, connection): class KeyTransformFactory: - def __init__(self, key_name): + def __init__(self, key_name, ref_field): self.key_name = key_name + self.ref_field = ref_field def __call__(self, *args, **kwargs): - return KeyTransform(self.key_name, *args, **kwargs) + return KeyTransform(self.key_name, self.ref_field, *args, **kwargs) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 10cd84f89..2f5fb1d27 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -1,6 +1,6 @@ import operator -from django.core.exceptions import ValidationError +from django.core.exceptions import FieldDoesNotExist, ValidationError from django.db import models from django.db.models import ExpressionWrapper, F, Max, Sum from django.test import SimpleTestCase, TestCase @@ -147,6 +147,41 @@ def test_nested(self): self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) +class InvalidLookupTests(SimpleTestCase): + def test_invalid_field(self): + msg = "Author has no field named 'first_name'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Book.objects.filter(author__first_name="Bob") + + def test_invalid_field_nested(self): + msg = "Address has no field named 'floor'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Book.objects.filter(author__address__floor="NYC") + + def test_invalid_lookup(self): + msg = "Unsupported lookup 'foo' for CharField 'city'." + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Book.objects.filter(author__address__city__foo="NYC") + + def test_invalid_lookup_with_suggestions(self): + msg = ( + "Unsupported lookup '{lookup}' for CharField 'name', " + "perhaps you meant {suggested_lookups}?" + ) + with self.assertRaisesMessage( + FieldDoesNotExist, msg.format(lookup="exactly", suggested_lookups="exact or iexact") + ): + Book.objects.filter(author__name__exactly="NYC") + with self.assertRaisesMessage( + FieldDoesNotExist, msg.format(lookup="gti", suggested_lookups="gt or gte") + ): + Book.objects.filter(author__name__gti="NYC") + with self.assertRaisesMessage( + FieldDoesNotExist, msg.format(lookup="is_null", suggested_lookups="isnull") + ): + Book.objects.filter(author__name__is_null="NYC") + + @isolate_apps("model_fields_") class CheckTests(SimpleTestCase): def test_no_relational_fields(self):