Skip to content

Commit c2a8578

Browse files
committed
EmbeddedModelArrayField Subquerying
1 parent 4fafdeb commit c2a8578

File tree

3 files changed

+72
-2
lines changed

3 files changed

+72
-2
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,47 @@ def as_mql(self, compiler, connection):
141141
return {"$anyElementTrue": lhs_mql}
142142

143143

144+
class ArrayAggregationSubqueryMixin:
145+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
146+
return [
147+
{
148+
"$facet": {
149+
"group": [
150+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
151+
{
152+
"$unwind": "$tmp_name",
153+
},
154+
{
155+
"$group": {
156+
"_id": None,
157+
"tmp_name": {"$addToSet": "$tmp_name"},
158+
}
159+
},
160+
]
161+
}
162+
},
163+
{
164+
"$project": {
165+
field_name: {
166+
"$ifNull": [
167+
{
168+
"$getField": {
169+
"input": {"$arrayElemAt": ["$group", 0]},
170+
"field": "tmp_name",
171+
}
172+
},
173+
[],
174+
]
175+
}
176+
}
177+
},
178+
]
179+
180+
144181
@_EmbeddedModelArrayOutputField.register_lookup
145-
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
182+
class EmbeddedModelArrayFieldIn(
183+
EmbeddedModelArrayFieldBuiltinLookup, lookups.In, ArrayAggregationSubqueryMixin
184+
):
146185
pass
147186

148187

tests/model_fields_/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ def __str__(self):
167167
return self.title
168168

169169

170+
class Audit(models.Model):
171+
related_section_number = models.IntegerField()
172+
reviewed = models.BooleanField()
173+
174+
170175
# An exhibit in the museum, composed of multiple sections.
171176
class Exhibit(models.Model):
172177
exhibit_name = models.CharField(max_length=255)

tests/model_fields_/test_embedded_model_array.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from django_mongodb_backend.fields import EmbeddedModelArrayField
99
from django_mongodb_backend.models import EmbeddedModel
1010

11-
from .models import Artifact, Exhibit, Movie, Restoration, Review, Section, Tour
11+
from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour
1212

1313

1414
class MethodTests(SimpleTestCase):
@@ -116,6 +116,7 @@ def setUpTestData(cls):
116116
],
117117
)
118118
],
119+
main_section=Section(section_number=2),
119120
)
120121
cls.lost_empires = Exhibit.objects.create(
121122
exhibit_name="Lost Empires",
@@ -146,6 +147,9 @@ def setUpTestData(cls):
146147
cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt)
147148
cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders)
148149
cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires)
150+
cls.audit_1 = Audit.objects.create(related_section_number=1, reviewed=True)
151+
cls.audit_2 = Audit.objects.create(related_section_number=2, reviewed=True)
152+
cls.audit_3 = Audit.objects.create(related_section_number=5, reviewed=False)
149153

150154
def test_exact(self):
151155
self.assertCountEqual(
@@ -284,6 +288,28 @@ def test_foreign_field_with_slice(self):
284288
qs = Tour.objects.filter(exhibit__sections__0_2__section_number__in=[1, 2])
285289
self.assertCountEqual(qs, [self.wonders_tour, self.egypt_tour])
286290

291+
def test_subquery_section_number_lt(self):
292+
subq = Audit.objects.filter(
293+
related_section_number__in=models.OuterRef("sections__section_number")
294+
).values("related_section_number")[:1]
295+
self.assertCountEqual(
296+
Exhibit.objects.filter(sections__section_number=subq),
297+
[self.egypt, self.wonders, self.new_descoveries],
298+
)
299+
300+
def test_check_in_subquery(self):
301+
subquery = Audit.objects.filter(reviewed=True).values_list(
302+
"related_section_number", flat=True
303+
)
304+
result = Exhibit.objects.filter(sections__section_number__in=subquery)
305+
self.assertCountEqual(result, [self.wonders, self.egypt, self.new_descoveries])
306+
307+
def test_array_as_rhs(self):
308+
result = Exhibit.objects.filter(
309+
main_section__section_number__in=models.F("sections__section_number")
310+
)
311+
self.assertCountEqual(result, [self.new_descoveries])
312+
287313

288314
@isolate_apps("model_fields_")
289315
class CheckTests(SimpleTestCase):

0 commit comments

Comments
 (0)