Skip to content

Commit 0f35e77

Browse files
committed
Allow querying an EmbeddedModelField by model instance
1 parent 8a57a06 commit 0f35e77

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from django.core import checks
44
from django.core.exceptions import FieldDoesNotExist
55
from django.db import models
6+
from django.db.models import lookups
67
from django.db.models.fields.related import lazy_related_operation
78
from django.db.models.lookups import Transform
89

910
from .. import forms
11+
from ..query_utils import process_lhs, process_rhs
1012
from .json import build_json_mql_path
1113

1214

@@ -149,6 +151,30 @@ def formfield(self, **kwargs):
149151
)
150152

151153

154+
@EmbeddedModelField.register_lookup
155+
class EMFExact(lookups.Exact):
156+
def model_to_dict(self, instance):
157+
"""Return a dict containing the data in a model instance."""
158+
data = {}
159+
for f in instance._meta.concrete_fields:
160+
value = f.value_from_object(instance)
161+
# Unless explicitly set, primary keys aren't included in embedded
162+
# models.
163+
if f.primary_key and value is None:
164+
continue
165+
data[f"{f.name}"] = value
166+
return data
167+
168+
def as_mql(self, compiler, connection):
169+
lhs_mql = process_lhs(self, compiler, connection)
170+
value = process_rhs(self, compiler, connection)
171+
if isinstance(value, models.Model):
172+
value = self.model_to_dict(value)
173+
prefix = self.lhs.as_mql(compiler, connection)
174+
return {"$and": [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()]}
175+
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
176+
177+
152178
class KeyTransform(Transform):
153179
def __init__(self, key_name, ref_field, *args, **kwargs):
154180
super().__init__(*args, **kwargs)
@@ -193,7 +219,13 @@ def preprocess_lhs(self, compiler, connection):
193219
previous = previous.lhs
194220
mql = previous.as_mql(compiler, connection)
195221
# The first json_key_transform is the field name.
196-
embedded_key_transforms.append(json_key_transforms.pop(0))
222+
try:
223+
field_name = json_key_transforms.pop(0)
224+
except IndexError:
225+
# This is a lookup of the embedded model itself.
226+
pass
227+
else:
228+
embedded_key_transforms.append(field_name)
197229
return mql, embedded_key_transforms, json_key_transforms
198230

199231
def as_mql(self, compiler, connection):

docs/source/topics/embedded-models.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,12 @@ as relational fields. For example, to retrieve all customers who have an
5454
address with the city "New York"::
5555

5656
>>> Customer.objects.filter(address__city="New York")
57+
58+
You can also query using a model instance. Unlike a normal relational lookup
59+
which does the lookup by primary key, since embedded models typically don't
60+
have a primary key set, the query requires that every field match. For example,
61+
this query gives customers with addresses with the city "New York" and all
62+
other fields of address equal to their default (:attr:`Field.default
63+
<django.db.models.Field.default>`, ``None``, or an empty string).
64+
65+
>>> Customer.objects.filter(address=Address(city="New York"))

tests/model_fields_/test_embedded_model.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import timedelta
33

44
from django.core.exceptions import FieldDoesNotExist, ValidationError
5-
from django.db import models
5+
from django.db import connection, models
66
from django.db.models import (
77
Exists,
88
ExpressionWrapper,
@@ -117,6 +117,33 @@ def test_order_by_embedded_field(self):
117117
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
118118
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))
119119

120+
def test_exact_with_model(self):
121+
data = Holder.objects.first().data
122+
self.assertEqual(
123+
Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer
124+
)
125+
126+
def test_exact_with_model_ignores_key_order(self):
127+
# Due to the possibility of schema changes or the reordering of a
128+
# model's fields, a lookup must work if an embedded document has its
129+
# keys in a different order than what's declared on the embedded model.
130+
connection.get_collection("model_fields__holder").insert_one(
131+
{
132+
"data": {
133+
"auto_now": None,
134+
"auto_now_add": None,
135+
"json_value": None,
136+
"integer": 100,
137+
}
138+
}
139+
)
140+
self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100)
141+
142+
def test_exact_with_nested_model(self):
143+
address = Address(city="NYC", state="NY")
144+
obj = Book.objects.create(author=Author(name="Shakespeare", age=55, address=address))
145+
self.assertCountEqual(Book.objects.filter(author__address=address), [obj])
146+
120147
def test_embedded_json_field_lookups(self):
121148
objs = [
122149
Holder.objects.create(

0 commit comments

Comments
 (0)