Skip to content

Commit 3ddce45

Browse files
committed
Handle output type as a separate field
1 parent 2de42ec commit 3ddce45

File tree

2 files changed

+121
-81
lines changed

2 files changed

+121
-81
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 109 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,10 @@
88
from .. import forms
99
from ..query_utils import process_lhs, process_rhs
1010
from . import EmbeddedModelField
11-
from .array import ArrayField
11+
from .array import ArrayField, ArrayLenTransform
1212

1313

1414
class EmbeddedModelArrayField(ArrayField):
15-
ALLOWED_LOOKUPS = {
16-
"in",
17-
"exact",
18-
"iexact",
19-
"gt",
20-
"gte",
21-
"lt",
22-
"lte",
23-
"all",
24-
"contained_by",
25-
}
26-
2715
def __init__(self, embedded_model, **kwargs):
2816
if "size" in kwargs:
2917
raise ValueError("EmbeddedModelArrayField does not support size.")
@@ -69,18 +57,50 @@ def get_transform(self, name):
6957
return transform
7058
return KeyTransformFactory(name, self)
7159

60+
def _get_lookup(self, lookup_name):
61+
lookup = super()._get_lookup(lookup_name)
62+
if lookup is None or lookup is ArrayLenTransform:
63+
return lookup
64+
65+
class EmbeddedModelArrayFieldLookups(Lookup):
66+
def as_mql(self, compiler, connection):
67+
raise ValueError(
68+
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
69+
"Try querying one of its embedded fields instead."
70+
)
71+
72+
return EmbeddedModelArrayFieldLookups
73+
74+
75+
class _EmbeddedModelArrayOutputField(ArrayField):
76+
"""
77+
Represents the output of an EmbeddedModelArrayField when traversed in a query path.
78+
79+
This field is not meant to be used directly in model definitions. It exists solely to
80+
support query output resolution; when an EmbeddedModelArrayField is accessed in a query,
81+
the result should behave like an array of the embedded model's target type.
82+
83+
While it mimics ArrayField's lookups behavior, the way those lookups are resolved
84+
follows the semantics of EmbeddedModelArrayField rather than native array behavior.
85+
"""
86+
87+
ALLOWED_LOOKUPS = {
88+
"in",
89+
"exact",
90+
"iexact",
91+
"gt",
92+
"gte",
93+
"lt",
94+
"lte",
95+
"all",
96+
"contained_by",
97+
}
98+
7299
def get_lookup(self, name):
73100
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
74101

75102

76103
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
77-
def check_lhs(self):
78-
if not isinstance(self.lhs, KeyTransform):
79-
raise ValueError(
80-
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
81-
"Try querying one of its embedded fields instead."
82-
)
83-
84104
def process_rhs(self, compiler, connection):
85105
value = self.rhs
86106
if not self.get_db_prep_lookup_value_is_iterable:
@@ -95,111 +115,114 @@ def process_rhs(self, compiler, connection):
95115
]
96116

97117
def as_mql(self, compiler, connection):
98-
self.check_lhs()
99118
# Querying a subfield within the array elements (via nested KeyTransform).
100119
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
101120
# `$in` on the subfield.
102-
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
121+
lhs_mql = process_lhs(self, compiler, connection)
122+
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
103123
values = process_rhs(self, compiler, connection)
104-
return {
105-
"$anyElementTrue": {
106-
"$ifNull": [
107-
{
108-
"$map": {
109-
"input": lhs_mql,
110-
"as": "item",
111-
"in": connection.mongo_operators[self.lookup_name](
112-
inner_lhs_mql, values
113-
),
114-
}
115-
},
116-
[],
117-
]
118-
}
119-
}
124+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
125+
inner_lhs_mql, values
126+
)
127+
return {"$anyElementTrue": lhs_mql}
120128

121129

122-
@EmbeddedModelArrayField.register_lookup
130+
@_EmbeddedModelArrayOutputField.register_lookup
123131
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
124-
pass
132+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
133+
return [
134+
{
135+
"$facet": {
136+
"group": [
137+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
138+
{
139+
"$unwind": "$tmp_name",
140+
},
141+
{
142+
"$group": {
143+
"_id": None,
144+
"tmp_name": {"$addToSet": "$tmp_name"},
145+
}
146+
},
147+
]
148+
}
149+
},
150+
{
151+
"$project": {
152+
field_name: {
153+
"$ifNull": [
154+
{
155+
"$getField": {
156+
"input": {"$arrayElemAt": ["$group", 0]},
157+
"field": "tmp_name",
158+
}
159+
},
160+
[],
161+
]
162+
}
163+
}
164+
},
165+
]
125166

126167

127-
@EmbeddedModelArrayField.register_lookup
168+
@_EmbeddedModelArrayOutputField.register_lookup
128169
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
129170
pass
130171

131172

132-
@EmbeddedModelArrayField.register_lookup
173+
@_EmbeddedModelArrayOutputField.register_lookup
133174
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
134175
get_db_prep_lookup_value_is_iterable = False
135176

136177

137-
@EmbeddedModelArrayField.register_lookup
178+
@_EmbeddedModelArrayOutputField.register_lookup
138179
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
139180
pass
140181

141182

142-
@EmbeddedModelArrayField.register_lookup
183+
@_EmbeddedModelArrayOutputField.register_lookup
143184
class EmbeddedModelArrayFieldGreaterThanOrEqual(
144185
EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
145186
):
146187
pass
147188

148189

149-
@EmbeddedModelArrayField.register_lookup
190+
@_EmbeddedModelArrayOutputField.register_lookup
150191
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
151192
pass
152193

153194

154-
@EmbeddedModelArrayField.register_lookup
195+
@_EmbeddedModelArrayOutputField.register_lookup
155196
class EmbeddedModelArrayFieldLessThanOrEqual(
156197
EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
157198
):
158199
pass
159200

160201

161-
@EmbeddedModelArrayField.register_lookup
202+
@_EmbeddedModelArrayOutputField.register_lookup
162203
class EmbeddedModelArrayFieldAll(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
163204
lookup_name = "all"
164205
get_db_prep_lookup_value_is_iterable = False
165206

166207
def as_mql(self, compiler, connection):
167-
self.check_lhs()
168-
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
208+
lhs_mql = process_lhs(self, compiler, connection)
169209
values = process_rhs(self, compiler, connection)
170210
return {
171-
"$setIsSubset": [
172-
values,
173-
{
174-
"$ifNull": [
175-
{
176-
"$map": {
177-
"input": lhs_mql,
178-
"as": "item",
179-
"in": inner_lhs_mql,
180-
}
181-
},
182-
[],
183-
]
184-
},
211+
"$and": [
212+
{"$ne": [lhs_mql, None]},
213+
{"$ne": [values, None]},
214+
{"$setIsSubset": [values, lhs_mql]},
185215
]
186216
}
187217

188218

189-
@EmbeddedModelArrayField.register_lookup
219+
@_EmbeddedModelArrayOutputField.register_lookup
190220
class ArrayContainedBy(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
191221
lookup_name = "contained_by"
192222
get_db_prep_lookup_value_is_iterable = False
193223

194224
def as_mql(self, compiler, connection):
195-
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
196-
lhs_mql = {
197-
"$map": {
198-
"input": lhs_mql,
199-
"as": "item",
200-
"in": inner_lhs_mql,
201-
}
202-
}
225+
lhs_mql = process_lhs(self, compiler, connection)
203226
value = process_rhs(self, compiler, connection)
204227
return {
205228
"$and": [
@@ -244,7 +267,7 @@ def get_transform(self, name):
244267
self._sub_transform = transform
245268
return self
246269
output_field = self._lhs.output_field
247-
allowed_lookups = self.array_field.ALLOWED_LOOKUPS.intersection(
270+
allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
248271
set(output_field.get_lookups())
249272
)
250273
suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
@@ -262,11 +285,22 @@ def get_transform(self, name):
262285
def as_mql(self, compiler, connection):
263286
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
264287
lhs_mql = process_lhs(self, compiler, connection)
265-
return lhs_mql, inner_lhs_mql
288+
return {
289+
"$ifNull": [
290+
{
291+
"$map": {
292+
"input": lhs_mql,
293+
"as": "item",
294+
"in": inner_lhs_mql,
295+
}
296+
},
297+
[],
298+
]
299+
}
266300

267301
@property
268302
def output_field(self):
269-
return self.array_field
303+
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
270304

271305

272306
class KeyTransformFactory:

tests/model_fields_/test_embedded_model_array.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_all_filter(self):
186186
def test_contained_by(self):
187187
self.assertCountEqual(
188188
MuseumExhibit.objects.filter(sections__section_number__contained_by=[1, 2, 3]),
189-
[self.egypt, self.new_descoveries, self.wonders],
189+
[self.egypt, self.new_descoveries, self.wonders, self.lost_empires],
190190
)
191191

192192
def test_len_filter(self):
@@ -258,12 +258,15 @@ def test_query_array_not_allowed(self):
258258
"Try querying one of its embedded fields instead."
259259
)
260260
with self.assertRaisesMessage(ValueError, msg):
261-
self.assertCountEqual(MuseumExhibit.objects.filter(sections=10), [])
261+
MuseumExhibit.objects.filter(sections=10).first()
262+
263+
with self.assertRaisesMessage(ValueError, msg):
264+
MuseumExhibit.objects.filter(sections__0_1=10).first()
262265

263266
def test_missing_field(self):
264267
msg = "ExhibitSection has no field named 'section'"
265268
with self.assertRaisesMessage(FieldDoesNotExist, msg):
266-
self.assertCountEqual(MuseumExhibit.objects.filter(sections__section__in=[10]), [])
269+
MuseumExhibit.objects.filter(sections__section__in=[10]).first()
267270

268271
def test_missing_lookup(self):
269272
msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'"
@@ -273,9 +276,7 @@ def test_missing_lookup(self):
273276
def test_missing_operation(self):
274277
msg = "Unsupported lookup 'rage' for EmbeddedModelArrayField of 'IntegerField'"
275278
with self.assertRaisesMessage(FieldDoesNotExist, msg):
276-
self.assertCountEqual(
277-
MuseumExhibit.objects.filter(sections__section_number__rage=[10]), []
278-
)
279+
MuseumExhibit.objects.filter(sections__section_number__rage=[10])
279280

280281
def test_missing_lookup_suggestions(self):
281282
msg = (
@@ -290,6 +291,11 @@ def test_double_emfarray_transform(self):
290291
with self.assertRaisesMessage(ValueError, msg):
291292
MuseumExhibit.objects.filter(sections__artifacts__name="")
292293

294+
def test_slice(self):
295+
self.assertSequenceEqual(
296+
MuseumExhibit.objects.filter(sections__0_1__section_number=2), [self.new_descoveries]
297+
)
298+
293299

294300
@isolate_apps("model_fields_")
295301
class CheckTests(SimpleTestCase):

0 commit comments

Comments
 (0)