Skip to content

Commit 08ae4b4

Browse files
committed
add ArrayField
1 parent 0bd3b7e commit 08ae4b4

File tree

24 files changed

+2480
-8
lines changed

24 files changed

+2480
-8
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ repos:
4444
hooks:
4545
- id: rstcheck
4646
additional_dependencies: [sphinx]
47+
args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"]
4748

4849
# We use the Python version instead of the original version which seems to require Docker
4950
# https://github.com/koalaman/shellcheck-precommit

django_mongodb_backend/features.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,15 @@ class DatabaseFeatures(BaseDatabaseFeatures):
8080
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
8181
# GenericRelation.value_to_string() assumes integer pk.
8282
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
83+
# contains with Exists() doesn't work:
84+
# https://github.com/mongodb-labs/django-mongodb/issues/204
85+
"model_fields_.test_arrayfield.QueryingTests.test_contains_subquery",
86+
# overlap with values() returns no results:
87+
# https://github.com/mongodb-labs/django-mongodb/issues/209
88+
"model_fields_.test_arrayfield.QueryingTests.test_overlap_values",
89+
# icontains doesn't work on ArrayField:
90+
# Unsupported conversion from array to string in $convert
91+
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
8392
}
8493
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
8594
_django_test_expected_failures_bitwise = {

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from .array import ArrayField
12
from .auto import ObjectIdAutoField
23
from .duration import register_duration_field
34
from .json import register_json_field
45
from .objectid import ObjectIdField
56

6-
__all__ = ["register_fields", "ObjectIdAutoField", "ObjectIdField"]
7+
__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"]
78

89

910
def register_fields():
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
import json
2+
3+
from django.contrib.postgres.validators import ArrayMaxLengthValidator
4+
from django.core import checks, exceptions
5+
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
6+
from django.db.models.fields.mixins import CheckFieldDefaultMixin
7+
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup
8+
from django.utils.translation import gettext_lazy as _
9+
10+
from ..forms import SimpleArrayField
11+
from ..query_utils import process_lhs, process_rhs
12+
from ..utils import prefix_validation_error
13+
14+
__all__ = ["ArrayField"]
15+
16+
17+
class AttributeSetter:
18+
def __init__(self, name, value):
19+
setattr(self, name, value)
20+
21+
22+
class ArrayField(CheckFieldDefaultMixin, Field):
23+
empty_strings_allowed = False
24+
default_error_messages = {
25+
"item_invalid": _("Item %(nth)s in the array did not validate:"),
26+
"nested_array_mismatch": _("Nested arrays must have the same length."),
27+
}
28+
_default_hint = ("list", "[]")
29+
30+
def __init__(self, base_field, size=None, **kwargs):
31+
self.base_field = base_field
32+
self.size = size
33+
if self.size:
34+
self.default_validators = [
35+
*self.default_validators,
36+
ArrayMaxLengthValidator(self.size),
37+
]
38+
# For performance, only add a from_db_value() method if the base field
39+
# implements it.
40+
if hasattr(self.base_field, "from_db_value"):
41+
self.from_db_value = self._from_db_value
42+
super().__init__(**kwargs)
43+
44+
@property
45+
def model(self):
46+
try:
47+
return self.__dict__["model"]
48+
except KeyError:
49+
raise AttributeError(
50+
"'%s' object has no attribute 'model'" % self.__class__.__name__
51+
) from None
52+
53+
@model.setter
54+
def model(self, model):
55+
self.__dict__["model"] = model
56+
self.base_field.model = model
57+
58+
@classmethod
59+
def _choices_is_value(cls, value):
60+
return isinstance(value, list | tuple) or super()._choices_is_value(value)
61+
62+
def check(self, **kwargs):
63+
errors = super().check(**kwargs)
64+
if self.base_field.remote_field:
65+
errors.append(
66+
checks.Error(
67+
"Base field for array cannot be a related field.",
68+
obj=self,
69+
id="django_mongodb_backend.array.E002",
70+
)
71+
)
72+
else:
73+
base_checks = self.base_field.check()
74+
if base_checks:
75+
error_messages = "\n ".join(
76+
f"{base_check.msg} ({base_check.id})"
77+
for base_check in base_checks
78+
if isinstance(base_check, checks.Error)
79+
)
80+
if error_messages:
81+
errors.append(
82+
checks.Error(
83+
f"Base field for array has errors:\n {error_messages}",
84+
obj=self,
85+
id="django_mongodb_backend.array.E001",
86+
)
87+
)
88+
warning_messages = "\n ".join(
89+
f"{base_check.msg} ({base_check.id})"
90+
for base_check in base_checks
91+
if isinstance(base_check, checks.Warning)
92+
)
93+
if warning_messages:
94+
errors.append(
95+
checks.Warning(
96+
f"Base field for array has warnings:\n {warning_messages}",
97+
obj=self,
98+
id="django_mongodb_backend.array.W004",
99+
)
100+
)
101+
return errors
102+
103+
def set_attributes_from_name(self, name):
104+
super().set_attributes_from_name(name)
105+
self.base_field.set_attributes_from_name(name)
106+
107+
@property
108+
def description(self):
109+
return f"Array of {self.base_field.description}"
110+
111+
def db_type(self, connection):
112+
return "array"
113+
114+
def get_db_prep_value(self, value, connection, prepared=False):
115+
if isinstance(value, list | tuple):
116+
# Workaround for https://code.djangoproject.com/ticket/35982
117+
# (fixed in Django 5.2).
118+
if isinstance(self.base_field, DecimalField):
119+
return [self.base_field.get_db_prep_save(i, connection) for i in value]
120+
return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
121+
return value
122+
123+
def deconstruct(self):
124+
name, path, args, kwargs = super().deconstruct()
125+
if path == "django_mongodb_backend.fields.array.ArrayField":
126+
path = "django_mongodb_backend.fields.ArrayField"
127+
kwargs.update(
128+
{
129+
"base_field": self.base_field.clone(),
130+
"size": self.size,
131+
}
132+
)
133+
return name, path, args, kwargs
134+
135+
def to_python(self, value):
136+
if isinstance(value, str):
137+
# Assume value is being deserialized.
138+
vals = json.loads(value)
139+
value = [self.base_field.to_python(val) for val in vals]
140+
return value
141+
142+
def _from_db_value(self, value, expression, connection):
143+
if value is None:
144+
return value
145+
return [self.base_field.from_db_value(item, expression, connection) for item in value]
146+
147+
def value_to_string(self, obj):
148+
values = []
149+
vals = self.value_from_object(obj)
150+
base_field = self.base_field
151+
152+
for val in vals:
153+
if val is None:
154+
values.append(None)
155+
else:
156+
obj = AttributeSetter(base_field.attname, val)
157+
values.append(base_field.value_to_string(obj))
158+
return json.dumps(values)
159+
160+
def get_transform(self, name):
161+
transform = super().get_transform(name)
162+
if transform:
163+
return transform
164+
if "_" not in name:
165+
try:
166+
index = int(name)
167+
except ValueError:
168+
pass
169+
else:
170+
return IndexTransformFactory(index, self.base_field)
171+
try:
172+
start, end = name.split("_")
173+
start = int(start)
174+
end = int(end)
175+
except ValueError:
176+
pass
177+
else:
178+
return SliceTransformFactory(start, end)
179+
180+
def validate(self, value, model_instance):
181+
super().validate(value, model_instance)
182+
for index, part in enumerate(value):
183+
try:
184+
self.base_field.validate(part, model_instance)
185+
except exceptions.ValidationError as error:
186+
raise prefix_validation_error(
187+
error,
188+
prefix=self.error_messages["item_invalid"],
189+
code="item_invalid",
190+
params={"nth": index + 1},
191+
) from None
192+
if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1:
193+
raise exceptions.ValidationError(
194+
self.error_messages["nested_array_mismatch"],
195+
code="nested_array_mismatch",
196+
)
197+
198+
def run_validators(self, value):
199+
super().run_validators(value)
200+
for index, part in enumerate(value):
201+
try:
202+
self.base_field.run_validators(part)
203+
except exceptions.ValidationError as error:
204+
raise prefix_validation_error(
205+
error,
206+
prefix=self.error_messages["item_invalid"],
207+
code="item_invalid",
208+
params={"nth": index + 1},
209+
) from None
210+
211+
def formfield(self, **kwargs):
212+
return super().formfield(
213+
**{
214+
"form_class": SimpleArrayField,
215+
"base_field": self.base_field.formfield(),
216+
"max_length": self.size,
217+
**kwargs,
218+
}
219+
)
220+
221+
222+
class Array(Func):
223+
def as_mql(self, compiler, connection):
224+
return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]
225+
226+
227+
class ArrayRHSMixin:
228+
def __init__(self, lhs, rhs):
229+
if isinstance(rhs, tuple | list):
230+
expressions = []
231+
for value in rhs:
232+
if not hasattr(value, "resolve_expression"):
233+
field = lhs.output_field
234+
value = Value(field.base_field.get_prep_value(value))
235+
expressions.append(value)
236+
rhs = Array(*expressions)
237+
super().__init__(lhs, rhs)
238+
239+
240+
@ArrayField.register_lookup
241+
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
242+
lookup_name = "contains"
243+
244+
def as_mql(self, compiler, connection):
245+
lhs_mql = process_lhs(self, compiler, connection)
246+
value = process_rhs(self, compiler, connection)
247+
return {"$and": [{"$ne": [lhs_mql, None]}, {"$setIsSubset": [value, lhs_mql]}]}
248+
249+
250+
@ArrayField.register_lookup
251+
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
252+
lookup_name = "contained_by"
253+
254+
def as_mql(self, compiler, connection):
255+
lhs_mql = process_lhs(self, compiler, connection)
256+
value = process_rhs(self, compiler, connection)
257+
return {"$and": [{"$ne": [lhs_mql, None]}, {"$setIsSubset": [lhs_mql, value]}]}
258+
259+
260+
@ArrayField.register_lookup
261+
class ArrayExact(ArrayRHSMixin, Exact):
262+
pass
263+
264+
265+
@ArrayField.register_lookup
266+
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
267+
lookup_name = "overlap"
268+
269+
def as_mql(self, compiler, connection):
270+
lhs_mql = process_lhs(self, compiler, connection)
271+
value = process_rhs(self, compiler, connection)
272+
return {
273+
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
274+
}
275+
276+
277+
@ArrayField.register_lookup
278+
class ArrayLenTransform(Transform):
279+
lookup_name = "len"
280+
output_field = IntegerField()
281+
282+
def as_mql(self, compiler, connection):
283+
lhs_mql = process_lhs(self, compiler, connection)
284+
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
285+
286+
287+
@ArrayField.register_lookup
288+
class ArrayInLookup(In):
289+
def get_prep_lookup(self):
290+
values = super().get_prep_lookup()
291+
if hasattr(values, "resolve_expression"):
292+
return values
293+
# process_rhs() expects hashable values, so convert lists to tuples.
294+
prepared_values = []
295+
for value in values:
296+
if hasattr(value, "resolve_expression"):
297+
prepared_values.append(value)
298+
else:
299+
prepared_values.append(tuple(value))
300+
return prepared_values
301+
302+
303+
class IndexTransform(Transform):
304+
def __init__(self, index, base_field, *args, **kwargs):
305+
super().__init__(*args, **kwargs)
306+
self.index = index
307+
self.base_field = base_field
308+
309+
def as_mql(self, compiler, connection):
310+
lhs_mql = process_lhs(self, compiler, connection)
311+
return {"$arrayElemAt": [lhs_mql, self.index]}
312+
313+
@property
314+
def output_field(self):
315+
return self.base_field
316+
317+
318+
class IndexTransformFactory:
319+
def __init__(self, index, base_field):
320+
self.index = index
321+
self.base_field = base_field
322+
323+
def __call__(self, *args, **kwargs):
324+
return IndexTransform(self.index, self.base_field, *args, **kwargs)
325+
326+
327+
class SliceTransform(Transform):
328+
def __init__(self, start, end, *args, **kwargs):
329+
super().__init__(*args, **kwargs)
330+
self.start = start
331+
self.end = end
332+
333+
def as_mql(self, compiler, connection):
334+
lhs_mql = process_lhs(self, compiler, connection)
335+
return {"$slice": [lhs_mql, self.start, self.end]}
336+
337+
338+
class SliceTransformFactory:
339+
def __init__(self, start, end):
340+
self.start = start
341+
self.end = end
342+
343+
def __call__(self, *args, **kwargs):
344+
return SliceTransform(self.start, self.end, *args, **kwargs)
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1-
from .fields import ObjectIdField
1+
from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget
22

3-
__all__ = ["ObjectIdField"]
3+
__all__ = [
4+
"SimpleArrayField",
5+
"SplitArrayField",
6+
"SplitArrayWidget",
7+
"ObjectIdField",
8+
]

0 commit comments

Comments
 (0)