Skip to content

Commit 9bfc500

Browse files
committed
Add get_subquery_wrapping_pipeline and unit test
1 parent e40bea8 commit 9bfc500

File tree

3 files changed

+108
-1
lines changed

3 files changed

+108
-1
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,40 @@ def as_mql(self, compiler, connection):
129129

130130
@_EmbeddedModelArrayOutputField.register_lookup
131131
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
132-
pass
132+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
133+
return [
134+
{
135+
"$facet": {
136+
"group": [
137+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
138+
{
139+
"$unwind": "$tmp_name",
140+
},
141+
{
142+
"$group": {
143+
"_id": None,
144+
"tmp_name": {"$addToSet": "$tmp_name"},
145+
}
146+
},
147+
]
148+
}
149+
},
150+
{
151+
"$project": {
152+
field_name: {
153+
"$ifNull": [
154+
{
155+
"$getField": {
156+
"input": {"$arrayElemAt": ["$group", 0]},
157+
"field": "tmp_name",
158+
}
159+
},
160+
[],
161+
]
162+
}
163+
}
164+
},
165+
]
133166

134167

135168
@_EmbeddedModelArrayOutputField.register_lookup
@@ -171,6 +204,41 @@ class EmbeddedModelArrayFieldAll(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
171204
lookup_name = "all"
172205
get_db_prep_lookup_value_is_iterable = False
173206

207+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
208+
return [
209+
{
210+
"$facet": {
211+
"group": [
212+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
213+
{
214+
"$unwind": "$tmp_name",
215+
},
216+
{
217+
"$group": {
218+
"_id": None,
219+
"tmp_name": {"$addToSet": "$tmp_name"},
220+
}
221+
},
222+
]
223+
}
224+
},
225+
{
226+
"$project": {
227+
field_name: {
228+
"$ifNull": [
229+
{
230+
"$getField": {
231+
"input": {"$arrayElemAt": ["$group", 0]},
232+
"field": "tmp_name",
233+
}
234+
},
235+
[],
236+
]
237+
}
238+
}
239+
},
240+
]
241+
174242
def as_mql(self, compiler, connection):
175243
lhs_mql = process_lhs(self, compiler, connection)
176244
values = process_rhs(self, compiler, connection)

tests/model_fields_/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ class ArtifactDetail(EmbeddedModel):
180180
last_restoration = EmbeddedModelField(RestorationRecord, null=True)
181181

182182

183+
class ExhibitAudit(models.Model):
184+
related_section_number = models.IntegerField()
185+
reviewed = models.BooleanField()
186+
187+
183188
# A section within an exhibit, containing multiple artifacts.
184189
class ExhibitSection(EmbeddedModel):
185190
section_number = models.IntegerField()

tests/model_fields_/test_embedded_model_array.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .models import (
1212
ArtifactDetail,
13+
ExhibitAudit,
1314
ExhibitSection,
1415
Movie,
1516
MuseumExhibit,
@@ -124,6 +125,7 @@ def setUpTestData(cls):
124125
],
125126
)
126127
],
128+
main_section=ExhibitSection(section_number=2),
127129
)
128130
cls.lost_empires = MuseumExhibit.objects.create(
129131
exhibit_name="Lost Empires",
@@ -154,6 +156,9 @@ def setUpTestData(cls):
154156
cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt)
155157
cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders)
156158
cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires)
159+
cls.audit_1 = ExhibitAudit.objects.create(related_section_number=1, reviewed=True)
160+
cls.audit_2 = ExhibitAudit.objects.create(related_section_number=2, reviewed=True)
161+
cls.audit_3 = ExhibitAudit.objects.create(related_section_number=5, reviewed=False)
157162

158163
def test_filter_with_field(self):
159164
self.assertCountEqual(
@@ -309,6 +314,35 @@ def test_foreign_field_with_slice(self):
309314
qs = Tour.objects.filter(exhibit__sections__0_2__section_number__all=[1, 2])
310315
self.assertEqual(list(qs), [self.wonders_tour])
311316

317+
def test_subquery_section_number_lt(self):
318+
subq = ExhibitAudit.objects.filter(
319+
related_section_number__in=models.OuterRef("sections__section_number")
320+
).values("related_section_number")[:1]
321+
self.assertCountEqual(
322+
MuseumExhibit.objects.filter(sections__section_number=subq),
323+
[self.egypt, self.wonders, self.new_descoveries],
324+
)
325+
326+
def test_check_all_subquery(self):
327+
subquery = ExhibitAudit.objects.filter(reviewed=True).values_list(
328+
"related_section_number", flat=True
329+
)
330+
result = MuseumExhibit.objects.filter(sections__section_number__all=subquery)
331+
self.assertCountEqual(result, [self.wonders])
332+
333+
def test_check_in_subquery(self):
334+
subquery = ExhibitAudit.objects.filter(reviewed=True).values_list(
335+
"related_section_number", flat=True
336+
)
337+
result = MuseumExhibit.objects.filter(sections__section_number__in=subquery)
338+
self.assertCountEqual(result, [self.wonders, self.egypt, self.new_descoveries])
339+
340+
def test_array_as_rhs(self):
341+
result = MuseumExhibit.objects.filter(
342+
main_section__section_number__in=models.F("sections__section_number")
343+
)
344+
self.assertCountEqual(result, [self.new_descoveries])
345+
312346

313347
@isolate_apps("model_fields_")
314348
class CheckTests(SimpleTestCase):

0 commit comments

Comments
 (0)