diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index d9dd5b6cf..cf45eac8a 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -7,7 +7,6 @@ from django.db.models.lookups import Transform from .. import forms -from .json import build_json_mql_path class EmbeddedModelField(models.Field): @@ -155,54 +154,41 @@ def __init__(self, key_name, ref_field, *args, **kwargs): self.key_name = str(key_name) self.ref_field = ref_field + def get_lookup(self, name): + return self.ref_field.get_lookup(name) + def get_transform(self, name): """ Validate that `name` is either a field of an embedded model or a lookup on an embedded model's field. """ - result = None - if isinstance(self.ref_field, EmbeddedModelField): - opts = self.ref_field.embedded_model._meta - new_field = opts.get_field(name) - result = KeyTransformFactory(name, new_field) + if transform := self.ref_field.get_transform(name): + return transform + suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups()) + if suggested_lookups: + suggested_lookups = " or ".join(suggested_lookups) + suggestion = f", perhaps you meant {suggested_lookups}?" else: - if self.ref_field.get_transform(name) is None: - suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups()) - if suggested_lookups: - suggested_lookups = " or ".join(suggested_lookups) - suggestion = f", perhaps you meant {suggested_lookups}?" - else: - suggestion = "." - raise FieldDoesNotExist( - f"Unsupported lookup '{name}' for " - f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'" - f"{suggestion}" - ) - result = KeyTransformFactory(name, self.ref_field) - return result + suggestion = "." + raise FieldDoesNotExist( + f"Unsupported lookup '{name}' for " + f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'" + f"{suggestion}" + ) - def preprocess_lhs(self, compiler, connection): + def as_mql(self, compiler, connection): previous = self - embedded_key_transforms = [] - json_key_transforms = [] + key_transforms = [] while isinstance(previous, KeyTransform): - if isinstance(previous.ref_field, EmbeddedModelField): - embedded_key_transforms.insert(0, previous.key_name) - else: - json_key_transforms.insert(0, previous.key_name) + key_transforms.insert(0, previous.key_name) previous = previous.lhs mql = previous.as_mql(compiler, connection) - # The first json_key_transform is the field name. - embedded_key_transforms.append(json_key_transforms.pop(0)) - return mql, embedded_key_transforms, json_key_transforms - - def as_mql(self, compiler, connection): - mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection) transforms = ".".join(key_transforms) - result = f"{mql}.{transforms}" - if json_key_transforms: - result = build_json_mql_path(result, json_key_transforms) - return result + return f"{mql}.{transforms}" + + @property + def output_field(self): + return self.ref_field class KeyTransformFactory: diff --git a/docs/source/releases/5.1.x.rst b/docs/source/releases/5.1.x.rst index 9ff9fe9a8..dde2e03a2 100644 --- a/docs/source/releases/5.1.x.rst +++ b/docs/source/releases/5.1.x.rst @@ -2,6 +2,15 @@ Django MongoDB Backend 5.1.x ============================ +5.1.0 beta 3 +============ + +*Unreleased* + +- Added support for a field's custom lookups and transforms in + ``EmbeddedModelField``, e.g. ``ArrayField``’s ``contains``, + ``contained__by``, ``len``, etc. + .. _django-mongodb-backend-5.1.0-beta-2: 5.1.0 beta 2 diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index 02a4e43ed..61b6b9e61 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -12,3 +12,10 @@ Initial release from the state of :ref:`django-mongodb-backend 5.1.0 beta 2 Regarding new features in Django 5.2, :class:`~django.db.models.CompositePrimaryKey` isn't supported. + +Bug fixes +--------- + +- Added support for a field's custom lookups and transforms in + ``EmbeddedModelField``, e.g. ``ArrayField``’s ``contains``, + ``contained__by``, ``len``, etc. diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index b25b94a1c..ad573323b 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -110,12 +110,14 @@ class Address(EmbeddedModel): city = models.CharField(max_length=20) state = models.CharField(max_length=2) zip_code = models.IntegerField(db_index=True) + tags = ArrayField(models.CharField(max_length=100), null=True, blank=True) class Author(EmbeddedModel): name = models.CharField(max_length=10) age = models.IntegerField() address = EmbeddedModelField(Address) + skills = ArrayField(models.CharField(max_length=100), null=True, blank=True) class Book(models.Model): diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index eee0dd1a9..700a3cf1c 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -186,6 +186,56 @@ def test_nested(self): self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) +class ArrayFieldTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.book = Book.objects.create( + author=Author( + name="Shakespeare", + age=55, + skills=["writing", "editing"], + address=Address(city="NYC", state="NY", tags=["home", "shipping"]), + ), + ) + + def test_contains(self): + self.assertCountEqual(Book.objects.filter(author__skills__contains=["nonexistent"]), []) + self.assertCountEqual( + Book.objects.filter(author__skills__contains=["writing"]), [self.book] + ) + # Nested + self.assertCountEqual( + Book.objects.filter(author__address__tags__contains=["nonexistent"]), [] + ) + self.assertCountEqual( + Book.objects.filter(author__address__tags__contains=["home"]), [self.book] + ) + + def test_contained_by(self): + self.assertCountEqual( + Book.objects.filter(author__skills__contained_by=["writing", "publishing"]), [] + ) + self.assertCountEqual( + Book.objects.filter(author__skills__contained_by=["writing", "editing", "publishing"]), + [self.book], + ) + # Nested + self.assertCountEqual( + Book.objects.filter(author__address__tags__contained_by=["home", "work"]), [] + ) + self.assertCountEqual( + Book.objects.filter(author__address__tags__contained_by=["home", "work", "shipping"]), + [self.book], + ) + + def test_len(self): + self.assertCountEqual(Book.objects.filter(author__skills__len=1), []) + self.assertCountEqual(Book.objects.filter(author__skills__len=2), [self.book]) + # Nested + self.assertCountEqual(Book.objects.filter(author__address__tags__len=1), []) + self.assertCountEqual(Book.objects.filter(author__address__tags__len=2), [self.book]) + + class InvalidLookupTests(SimpleTestCase): def test_invalid_field(self): msg = "Author has no field named 'first_name'"