Skip to content

Commit 039af42

Browse files
committed
add overlap lookup
1 parent 1e393c4 commit 039af42

File tree

4 files changed

+74
-0
lines changed

4 files changed

+74
-0
lines changed

django_mongodb/features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
8383
# contains with Exists() doesn't work:
8484
# https://github.com/mongodb-labs/django-mongodb/issues/204
8585
"model_fields_.test_arrayfield.QueryingTests.test_contains_subquery",
86+
# overlap with values() returns no results:
87+
# https://github.com/mongodb-labs/django-mongodb/issues/209
88+
"model_fields_.test_arrayfield.QueryingTests.test_overlap_values",
8689
# icontains doesn't work on ArrayField:
8790
# Unsupported conversion from array to string in $convert
8891
"model_fields_.test_arrayfield.QueryingTests.test_icontains",

django_mongodb/fields/array.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,18 @@ class ArrayExact(ArrayRHSMixin, Exact):
253253
pass
254254

255255

256+
@ArrayField.register_lookup
257+
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
258+
lookup_name = "overlap"
259+
260+
def as_mql(self, compiler, connection):
261+
lhs_mql = process_lhs(self, compiler, connection)
262+
value = process_rhs(self, compiler, connection)
263+
return {
264+
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
265+
}
266+
267+
256268
@ArrayField.register_lookup
257269
class ArrayLenTransform(Transform):
258270
lookup_name = "len"

docs/source/fields.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,26 @@ data. It uses the ``$setIntersection`` operator. For example:
106106
>>> Post.objects.filter(tags__contains=["django", "thoughts"])
107107
<QuerySet [<Post: First post>]>
108108
109+
.. fieldlookup:: arrayfield.overlap
110+
111+
``overlap``
112+
~~~~~~~~~~~
113+
114+
Returns objects where the data shares any results with the values passed. It
115+
uses the ``$setIntersection`` operator. For example:
116+
117+
.. code-block:: pycon
118+
119+
>>> Post.objects.create(name="First post", tags=["thoughts", "django"])
120+
>>> Post.objects.create(name="Second post", tags=["thoughts", "tutorial"])
121+
>>> Post.objects.create(name="Third post", tags=["tutorial", "django"])
122+
123+
>>> Post.objects.filter(tags__overlap=["thoughts"])
124+
<QuerySet [<Post: First post>, <Post: Second post>]>
125+
126+
>>> Post.objects.filter(tags__overlap=["thoughts", "tutorial"])
127+
<QuerySet [<Post: First post>, <Post: Second post>, <Post: Third post>]>
128+
109129
.. fieldlookup:: arrayfield.len
110130

111131
``len``

tests/model_fields_/test_arrayfield.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from django.core.management import call_command
1111
from django.db import IntegrityError, connection, models
1212
from django.db.models.expressions import Exists, OuterRef, Value
13+
from django.db.models.functions import Upper
1314
from django.test import (
1415
SimpleTestCase,
1516
TestCase,
@@ -369,6 +370,38 @@ def test_icontains(self):
369370
def test_contains_charfield(self):
370371
self.assertSequenceEqual(CharArrayModel.objects.filter(field__contains=["text"]), [])
371372

373+
def test_overlap_charfield(self):
374+
self.assertSequenceEqual(CharArrayModel.objects.filter(field__overlap=["text"]), [])
375+
376+
def test_overlap_charfield_including_expression(self):
377+
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
378+
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
379+
CharArrayModel.objects.create(field=["lower text", "text"])
380+
self.assertSequenceEqual(
381+
CharArrayModel.objects.filter(
382+
field__overlap=[
383+
Upper(Value("text")),
384+
"other",
385+
]
386+
),
387+
[obj_1, obj_2],
388+
)
389+
390+
def test_overlap_values(self):
391+
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
392+
self.assertCountEqual(
393+
NullableIntegerArrayModel.objects.filter(
394+
field__overlap=qs.values_list("field"),
395+
),
396+
self.objs[:3],
397+
)
398+
self.assertCountEqual(
399+
NullableIntegerArrayModel.objects.filter(
400+
field__overlap=qs.values("field"),
401+
),
402+
self.objs[:3],
403+
)
404+
372405
def test_index(self):
373406
self.assertSequenceEqual(
374407
NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
@@ -389,6 +422,12 @@ def test_index_used_on_nested_data(self):
389422
NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
390423
)
391424

425+
def test_overlap(self):
426+
self.assertSequenceEqual(
427+
NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
428+
self.objs[0:3],
429+
)
430+
392431
def test_index_annotation(self):
393432
qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1"))
394433
self.assertCountEqual(

0 commit comments

Comments
 (0)