Skip to content

Commit 1476769

Browse files
committed
Allow querying an EmbeddedModelField by model instance
1 parent 8a57a06 commit 1476769

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from django.core import checks
44
from django.core.exceptions import FieldDoesNotExist
55
from django.db import models
6+
from django.db.models import lookups
67
from django.db.models.fields.related import lazy_related_operation
78
from django.db.models.lookups import Transform
89

910
from .. import forms
11+
from ..query_utils import process_lhs, process_rhs
1012
from .json import build_json_mql_path
1113

1214

@@ -149,6 +151,41 @@ def formfield(self, **kwargs):
149151
)
150152

151153

154+
@EmbeddedModelField.register_lookup
155+
class EMFExact(lookups.Exact):
156+
def model_to_dict(self, instance):
157+
"""Return a dict containing the data in a model instance."""
158+
data = {}
159+
emf_data = {}
160+
for f in instance._meta.concrete_fields:
161+
value = f.value_from_object(instance)
162+
if isinstance(f, EmbeddedModelField):
163+
emf_data[f"{f.name}"] = self.model_to_dict(value)
164+
continue
165+
# Unless explicitly set, primary keys aren't included in embedded
166+
# models.
167+
if f.primary_key and value is None:
168+
continue
169+
data[f"{f.name}"] = value
170+
return data, emf_data
171+
172+
def as_mql(self, compiler, connection):
173+
lhs_mql = process_lhs(self, compiler, connection)
174+
value = process_rhs(self, compiler, connection)
175+
if isinstance(value, models.Model):
176+
value, emf_data = self.model_to_dict(value)
177+
prefix = self.lhs.as_mql(compiler, connection)
178+
conditions = [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()]
179+
# TODO: more tests for this logic. Might not work for another
180+
# layer of embedding.
181+
while emf_data:
182+
for k, v in emf_data.items():
183+
v, emf_data = v
184+
conditions += [{"$eq": [f"{prefix}.{k}.{x}", y]} for x, y in v.items()]
185+
return {"$and": conditions}
186+
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
187+
188+
152189
class KeyTransform(Transform):
153190
def __init__(self, key_name, ref_field, *args, **kwargs):
154191
super().__init__(*args, **kwargs)
@@ -193,7 +230,13 @@ def preprocess_lhs(self, compiler, connection):
193230
previous = previous.lhs
194231
mql = previous.as_mql(compiler, connection)
195232
# The first json_key_transform is the field name.
196-
embedded_key_transforms.append(json_key_transforms.pop(0))
233+
try:
234+
field_name = json_key_transforms.pop(0)
235+
except IndexError:
236+
# This is a lookup of the embedded model itself.
237+
pass
238+
else:
239+
embedded_key_transforms.append(field_name)
197240
return mql, embedded_key_transforms, json_key_transforms
198241

199242
def as_mql(self, compiler, connection):

docs/source/topics/embedded-models.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,12 @@ as relational fields. For example, to retrieve all customers who have an
5454
address with the city "New York"::
5555

5656
>>> Customer.objects.filter(address__city="New York")
57+
58+
You can also query using a model instance. Unlike a normal relational lookup
59+
which does the lookup by primary key, since embedded models typically don't
60+
have a primary key set, the query requires that every field match. For example,
61+
this query gives customers with addresses with the city "New York" and all
62+
other fields of address equal to their default (:attr:`Field.default
63+
<django.db.models.Field.default>`, ``None``, or an empty string).
64+
65+
>>> Customer.objects.filter(address=Address(city="New York"))

tests/model_fields_/test_embedded_model.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import timedelta
33

44
from django.core.exceptions import FieldDoesNotExist, ValidationError
5-
from django.db import models
5+
from django.db import connection, models
66
from django.db.models import (
77
Exists,
88
ExpressionWrapper,
@@ -117,6 +117,38 @@ def test_order_by_embedded_field(self):
117117
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
118118
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))
119119

120+
def test_exact_with_model(self):
121+
data = Holder.objects.first().data
122+
self.assertEqual(
123+
Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer
124+
)
125+
126+
def test_exact_with_model_ignores_key_order(self):
127+
# Due to the possibility of schema changes or the reordering of a
128+
# model's fields, a lookup must work if an embedded document has its
129+
# keys in a different order than what's declared on the embedded model.
130+
connection.get_collection("model_fields__holder").insert_one(
131+
{
132+
"data": {
133+
"auto_now": None,
134+
"auto_now_add": None,
135+
"json_value": None,
136+
"integer": 100,
137+
}
138+
}
139+
)
140+
self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100)
141+
142+
def test_exact_with_nested_model(self):
143+
address = Address(city="NYC", state="NY")
144+
obj = Book.objects.create(author=Author(name="Shakespeare", age=55, address=address))
145+
self.assertCountEqual(Book.objects.filter(author__address=address), [obj])
146+
147+
def test_exact_with_model_with_embedded_modelt(self):
148+
author = Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
149+
obj = Book.objects.create(author=author)
150+
self.assertCountEqual(Book.objects.filter(author=author), [obj])
151+
120152
def test_embedded_json_field_lookups(self):
121153
objs = [
122154
Holder.objects.create(

0 commit comments

Comments
 (0)