diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 57bbd3f50..ace077696 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -3,10 +3,13 @@ from django.core import checks from django.core.exceptions import FieldDoesNotExist from django.db import models +from django.db.models import lookups +from django.db.models.expressions import Col from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform from .. import forms +from ..query_utils import process_lhs, process_rhs class EmbeddedModelField(models.Field): @@ -151,6 +154,67 @@ def formfield(self, **kwargs): ) +@EmbeddedModelField.register_lookup +class EMFExact(lookups.Exact): + def model_to_dict(self, instance, connection): + """ + Return a dict containing the data in a model instance, as well as a + dict containing the data for any embedded model fields. + """ + data = {} + emf_data = {} + for f in instance._meta.concrete_fields: + value = f.get_db_prep_value(f.value_from_object(instance), connection) + if isinstance(f, EmbeddedModelField): + emf_data[f.name] = ( + self.model_to_dict(value, connection) if value is not None else (None, {}) + ) + continue + # Unless explicitly set, primary keys aren't included in embedded + # models. + if f.primary_key and value is None: + continue + data[f.name] = value + return data, emf_data + + def get_conditions(self, emf_data, prefix=None): + """ + Recursively transform a dictionary of {"field_name": {}} + lookups into MQL. `prefix` tracks the string that must be appended to + nested fields. + """ + conditions = [] + for k, v in emf_data.items(): + v, emf_data = v + subprefix = f"{prefix}.{k}" if prefix else k + conditions += self.get_conditions(emf_data, subprefix) + if v is not None: + # Match all fields of the EmbeddedModelField. + conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()] + else: + # Match a null EmbeddedModelField. + conditions += [{"$eq": [f"{subprefix}", None]}] + return conditions + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + if isinstance(self.lhs, Col) or ( + isinstance(self.lhs, KeyTransform) + and isinstance(self.lhs.ref_field, EmbeddedModelField) + ): + if isinstance(value, models.Model): + value, emf_data = self.model_to_dict(value, connection) + # Get conditions for any nested EmbeddedModelFields. + conditions = self.get_conditions({lhs_mql: (value, emf_data)}) + return {"$and": conditions} + raise TypeError( + "An EmbeddedModelField must be queried using a model instance, got %s." + % type(value) + ) + return connection.mongo_operators[self.lookup_name](lhs_mql, value) + + class KeyTransform(Transform): def __init__(self, key_name, ref_field, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index edad165d0..4cf6a4bdc 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -22,6 +22,7 @@ New features a model's :attr:`Meta.indexes `. - PyMongo's connection pooling is now used by default. See :ref:`connection-management`. +- Allowed ``EmbeddedModelField``’s ``exact`` lookup to use a model instance. Backwards incompatible changes ------------------------------ diff --git a/docs/source/topics/embedded-models.rst b/docs/source/topics/embedded-models.rst index 94abecfd2..828c49f47 100644 --- a/docs/source/topics/embedded-models.rst +++ b/docs/source/topics/embedded-models.rst @@ -54,3 +54,16 @@ as relational fields. For example, to retrieve all customers who have an address with the city "New York":: >>> Customer.objects.filter(address__city="New York") + +You can also query using a model instance. Unlike a normal relational lookup +which does the lookup by primary key, since embedded models typically don't +have a primary key set, the query requires that every field match. For example, +this query gives customers with addresses with the city "New York" and all +other fields of the address equal to their default (:attr:`Field.default +`, ``None``, or an empty string). + + >>> Customer.objects.filter(address=Address(city="New York")) + +.. versionadded:: 5.2.0b0 + + The ability to query by model instance was added. diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 2470f4bb8..36b78aee8 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -138,3 +138,31 @@ class Library(models.Model): def __str__(self): return self.name + + +class A(models.Model): + b = EmbeddedModelField("B") + + +class B(EmbeddedModel): + c = EmbeddedModelField("C") + name = models.CharField(max_length=100) + value = models.IntegerField() + + +class C(EmbeddedModel): + d = EmbeddedModelField("D") + name = models.CharField(max_length=100) + value = models.IntegerField() + + +class D(EmbeddedModel): + e = EmbeddedModelField("E") + nullable_e = EmbeddedModelField("E", null=True, blank=True) + name = models.CharField(max_length=100) + value = models.IntegerField() + + +class E(EmbeddedModel): + name = models.CharField(max_length=100) + value = models.IntegerField() diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index ec9f9dfc4..e2b3bd302 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -2,7 +2,7 @@ from datetime import timedelta from django.core.exceptions import FieldDoesNotExist, ValidationError -from django.db import models +from django.db import connection, models from django.db.models import ( Exists, ExpressionWrapper, @@ -17,15 +17,7 @@ from django_mongodb_backend.fields import EmbeddedModelField from django_mongodb_backend.models import EmbeddedModel -from .models import ( - Address, - Author, - Book, - Data, - Holder, - Library, - NestedData, -) +from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library, NestedData from .utils import truncate_ms @@ -145,6 +137,62 @@ def test_order_by_embedded_field(self): qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer") self.assertSequenceEqual(qs, list(reversed(self.objs[4:]))) + def test_exact_with_model(self): + data = Holder.objects.first().data + self.assertEqual( + Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer + ) + + def test_exact_with_model_ignores_key_order(self): + # Due to the possibility of schema changes or the reordering of a + # model's fields, a lookup must work if an embedded document has its + # keys in a different order than what's declared on the embedded model. + data = {} + for field in reversed(Data._meta.fields): + data[field.name] = None + del data["id"] + data["integer"] = 100 + connection.get_collection("model_fields__holder").insert_one({"data": data}) + self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100) + + def test_exact_with_nested_model(self): + address = Address(city="NYC", state="NY") + author = Author(name="Shakespeare", age=55, address=address) + obj = Book.objects.create(author=author) + self.assertCountEqual(Book.objects.filter(author=author), [obj]) + self.assertCountEqual(Book.objects.filter(author__address=address), [obj]) + + def test_exact_with_deeply_nested_models(self): + e1 = E(name="E1", value=5) + d1 = D(name="D1", value=4, e=e1) + c1 = C(name="C1", value=3, d=d1) + b1 = B(name="B1", value=2, c=c1) + a1 = A.objects.create(b=b1) + e2 = E(name="E2", value=6) + d2 = D(name="D2", value=4, e=e1, nullable_e=e2) + c2 = C(name="C2", value=3, d=d2) + b2 = B(name="B2", value=2, c=c2) + a2 = A.objects.create(b=b2) + self.assertCountEqual(A.objects.filter(b=b1), [a1]) + self.assertCountEqual(A.objects.filter(b__c=c1), [a1]) + self.assertCountEqual(A.objects.filter(b__c__d=d1), [a1]) + self.assertCountEqual(A.objects.filter(b__c__d__e=e1), [a1, a2]) + self.assertCountEqual(A.objects.filter(b=b2), [a2]) + self.assertCountEqual(A.objects.filter(b__c=c2), [a2]) + self.assertCountEqual(A.objects.filter(b__c__d=d2), [a2]) + self.assertCountEqual(A.objects.filter(b__c__d__nullable_e=e2), [a2]) + + def test_exact_validates_argument(self): + msg = "An EmbeddedModelField must be queried using a model instance, got ." + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b={})) + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b__c={})) + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b__c__d={})) + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b__c__d__e={})) + def test_embedded_json_field_lookups(self): objs = [ Holder.objects.create(