Skip to content

Commit dd25fa8

Browse files
WaVEVtimgraham
authored andcommitted
add ObjectIdField
1 parent 36c5718 commit dd25fa8

File tree

8 files changed

+273
-13
lines changed

8 files changed

+273
-13
lines changed

django_mongodb/fields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .auto import ObjectIdAutoField
22
from .duration import register_duration_field
33
from .json import register_json_field
4+
from .objectid import ObjectIdField
45

5-
__all__ = ["register_fields", "ObjectIdAutoField"]
6+
__all__ = ["register_fields", "ObjectIdAutoField", "ObjectIdField"]
67

78

89
def register_fields():

django_mongodb/fields/auto.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@
22
from django.core import exceptions
33
from django.db.models.fields import AutoField
44
from django.utils.functional import cached_property
5-
from django.utils.translation import gettext_lazy as _
65

6+
from .objectid import ObjectIdMixin
77

8-
class ObjectIdAutoField(AutoField):
9-
default_error_messages = {
10-
"invalid": _("“%(value)s” value must be an Object Id."),
11-
}
12-
description = _("Object Id")
138

9+
class ObjectIdAutoField(ObjectIdMixin, AutoField):
1410
def __init__(self, *args, **kwargs):
1511
kwargs["db_column"] = "_id"
1612
super().__init__(*args, **kwargs)
@@ -42,12 +38,6 @@ def get_prep_value(self, value):
4238
def get_internal_type(self):
4339
return "ObjectIdAutoField"
4440

45-
def db_type(self, connection):
46-
return "objectId"
47-
48-
def rel_db_type(self, connection):
49-
return "objectId"
50-
5141
def to_python(self, value):
5242
if value is None or isinstance(value, int):
5343
return value

django_mongodb/fields/objectid.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from bson import ObjectId, errors
2+
from django.core import exceptions
3+
from django.db.models.fields import Field
4+
from django.utils.translation import gettext_lazy as _
5+
6+
7+
class ObjectIdMixin:
8+
default_error_messages = {
9+
"invalid": _("“%(value)s” value must be an Object Id."),
10+
}
11+
description = _("Object Id")
12+
13+
def db_type(self, connection):
14+
return "objectId"
15+
16+
def rel_db_type(self, connection):
17+
return "objectId"
18+
19+
def get_prep_value(self, value):
20+
if value is None or isinstance(value, ObjectId):
21+
return value
22+
try:
23+
return ObjectId(value)
24+
except (errors.InvalidId, TypeError) as e:
25+
raise ValueError(f"Field '{self.name}' expected an ObjectId but got {value!r}.") from e
26+
27+
def to_python(self, value):
28+
if value is None:
29+
return value
30+
try:
31+
return ObjectId(value)
32+
except (errors.InvalidId, TypeError):
33+
raise exceptions.ValidationError(
34+
self.error_messages["invalid"],
35+
code="invalid",
36+
params={"value": value},
37+
) from None
38+
39+
40+
class ObjectIdField(ObjectIdMixin, Field):
41+
def deconstruct(self):
42+
name, path, args, kwargs = super().deconstruct()
43+
if path.startswith("django_mongodb.fields.objectid"):
44+
path = path.replace("django_mongodb.fields.objectid", "django_mongodb.fields")
45+
return name, path, args, kwargs
46+
47+
def get_internal_type(self):
48+
return "ObjectIdField"

docs/source/fields.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Model field reference
2+
=====================
3+
4+
.. module:: django_mongodb.fields
5+
6+
Some MongoDB-specific fields are available in ``django_mongodb.fields``.
7+
8+
``ObjectIdField``
9+
-----------------
10+
11+
.. class:: ObjectIdField
12+
13+
Stores an :class:`~bson.objectid.ObjectId`.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ django-mongodb 5.0.x documentation
55
:maxdepth: 1
66
:caption: Contents:
77

8+
fields
89
querysets
910

1011
Indices and tables
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from bson import ObjectId
2+
from django.core import exceptions
3+
from django.test import SimpleTestCase
4+
5+
from django_mongodb.fields import ObjectIdField
6+
7+
8+
class MethodTests(SimpleTestCase):
9+
def test_deconstruct(self):
10+
field = ObjectIdField()
11+
name, path, args, kwargs = field.deconstruct()
12+
self.assertEqual(path, "django_mongodb.fields.ObjectIdField")
13+
self.assertEqual(args, [])
14+
self.assertEqual(kwargs, {})
15+
16+
def test_get_internal_type(self):
17+
f = ObjectIdField()
18+
self.assertEqual(f.get_internal_type(), "ObjectIdField")
19+
20+
def test_to_python_string(self):
21+
value = "1" * 24
22+
self.assertEqual(ObjectIdField().to_python(value), ObjectId(value))
23+
24+
def test_to_python_objectid(self):
25+
value = ObjectId("1" * 24)
26+
self.assertEqual(ObjectIdField().to_python(value), value)
27+
28+
def test_to_python_null(self):
29+
self.assertIsNone(ObjectIdField().to_python(None))
30+
31+
def test_to_python_invalid_value(self):
32+
f = ObjectIdField()
33+
for invalid_value in ["None", {}, [], 123]:
34+
with self.subTest(invalid_value=invalid_value):
35+
msg = f"['“{invalid_value}” value must be an Object Id.']"
36+
with self.assertRaisesMessage(exceptions.ValidationError, msg):
37+
f.to_python(invalid_value)
38+
39+
def test_get_prep_value_string(self):
40+
value = "1" * 24
41+
self.assertEqual(ObjectIdField().get_prep_value(value), ObjectId(value))
42+
43+
def test_get_prep_value_objectid(self):
44+
value = ObjectId("1" * 24)
45+
self.assertEqual(ObjectIdField().get_prep_value(value), value)
46+
47+
def test_get_prep_value_null(self):
48+
self.assertIsNone(ObjectIdField().get_prep_value(None))
49+
50+
def test_get_prep_value_invalid_values(self):
51+
f = ObjectIdField()
52+
f.name = "test"
53+
for invalid_value in ["None", {}, [], 123]:
54+
with self.subTest(invalid_value=invalid_value):
55+
msg = f"Field 'test' expected an ObjectId but got {invalid_value!r}."
56+
with self.assertRaisesMessage(ValueError, msg):
57+
f.get_prep_value(invalid_value)

tests/queries_/models.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from django.db import models
22

3+
from django_mongodb.fields import ObjectIdAutoField, ObjectIdField
4+
35

46
class Author(models.Model):
57
name = models.CharField(max_length=10)
@@ -14,3 +16,40 @@ class Book(models.Model):
1416

1517
def __str__(self):
1618
return self.title
19+
20+
21+
class Tag(models.Model):
22+
name = models.CharField(max_length=10)
23+
parent = models.ForeignKey(
24+
"self",
25+
models.SET_NULL,
26+
blank=True,
27+
null=True,
28+
related_name="children",
29+
)
30+
group_id = ObjectIdField(null=True)
31+
32+
def __str__(self):
33+
return self.name
34+
35+
36+
class Order(models.Model):
37+
id = ObjectIdAutoField(primary_key=True)
38+
name = models.CharField(max_length=12, null=True, default="")
39+
40+
class Meta:
41+
ordering = ("pk",)
42+
43+
def __str__(self):
44+
return str(self.pk)
45+
46+
47+
class OrderItem(models.Model):
48+
order = models.ForeignKey(Order, models.CASCADE, related_name="items")
49+
status = ObjectIdField(null=True)
50+
51+
class Meta:
52+
ordering = ("pk",)
53+
54+
def __str__(self):
55+
return str(self.pk)

tests/queries_/test_objectid.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from bson import ObjectId
2+
from django.test import TestCase
3+
4+
from .models import Order, OrderItem, Tag
5+
6+
7+
class ObjectIdTests(TestCase):
8+
@classmethod
9+
def setUpTestData(cls):
10+
cls.group_id_str_1 = "1" * 24
11+
cls.group_id_obj_1 = ObjectId(cls.group_id_str_1)
12+
cls.group_id_str_2 = "2" * 24
13+
cls.group_id_obj_2 = ObjectId(cls.group_id_str_2)
14+
15+
cls.t1 = Tag.objects.create(name="t1")
16+
cls.t2 = Tag.objects.create(name="t2", parent=cls.t1)
17+
cls.t3 = Tag.objects.create(name="t3", parent=cls.t1, group_id=cls.group_id_str_1)
18+
cls.t4 = Tag.objects.create(name="t4", parent=cls.t3, group_id=cls.group_id_obj_2)
19+
cls.t5 = Tag.objects.create(name="t5", parent=cls.t3)
20+
21+
def test_filter_group_id_is_null_false(self):
22+
"""Filter objects where group_id is not null."""
23+
qs = Tag.objects.filter(group_id__isnull=False).order_by("name")
24+
self.assertSequenceEqual(qs, [self.t3, self.t4])
25+
26+
def test_filter_group_id_is_null_true(self):
27+
"""Filter objects where group_id is null."""
28+
qs = Tag.objects.filter(group_id__isnull=True).order_by("name")
29+
self.assertSequenceEqual(qs, [self.t1, self.t2, self.t5])
30+
31+
def test_filter_group_id_equal_str(self):
32+
"""Filter by group_id with a specific string value."""
33+
qs = Tag.objects.filter(group_id=self.group_id_str_1).order_by("name")
34+
self.assertSequenceEqual(qs, [self.t3])
35+
36+
def test_filter_group_id_equal_obj(self):
37+
"""Filter by group_id with a specific ObjectId value."""
38+
qs = Tag.objects.filter(group_id=self.group_id_obj_1).order_by("name")
39+
self.assertSequenceEqual(qs, [self.t3])
40+
41+
def test_filter_group_id_in_str_values(self):
42+
"""Filter by group_id with string values in a list."""
43+
ids = [self.group_id_str_1, self.group_id_str_2]
44+
qs = Tag.objects.filter(group_id__in=ids).order_by("name")
45+
self.assertSequenceEqual(qs, [self.t3, self.t4])
46+
47+
def test_filter_group_id_in_obj_values(self):
48+
"""Filter by group_id with ObjectId values in a list."""
49+
ids = [self.group_id_obj_1, self.group_id_obj_2]
50+
qs = Tag.objects.filter(group_id__in=ids).order_by("name")
51+
self.assertSequenceEqual(qs, [self.t3, self.t4])
52+
53+
def test_filter_group_id_equal_subquery(self):
54+
"""Filter by group_id using a subquery."""
55+
subquery = Tag.objects.filter(name="t3").values("group_id")
56+
qs = Tag.objects.filter(group_id__in=subquery).order_by("name")
57+
self.assertSequenceEqual(qs, [self.t3])
58+
59+
def test_filter_group_id_in_subquery(self):
60+
"""Filter by group_id using a subquery with multiple values."""
61+
subquery = Tag.objects.filter(name__in=["t3", "t4"]).values("group_id")
62+
qs = Tag.objects.filter(group_id__in=subquery).order_by("name")
63+
self.assertSequenceEqual(qs, [self.t3, self.t4])
64+
65+
def test_filter_parent_by_children_values_str(self):
66+
"""Query to select parents of children with specific string group_id."""
67+
child_ids = Tag.objects.filter(group_id=self.group_id_str_1).values_list("id", flat=True)
68+
parent_qs = Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name")
69+
self.assertSequenceEqual(parent_qs, [self.t1])
70+
71+
def test_filter_parent_by_children_values_obj(self):
72+
"""Query to select parents of children with specific ObjectId group_id."""
73+
child_ids = Tag.objects.filter(group_id=self.group_id_obj_1).values_list("id", flat=True)
74+
parent_qs = Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name")
75+
self.assertSequenceEqual(parent_qs, [self.t1])
76+
77+
def test_filter_group_id_union_with_str(self):
78+
"""Combine queries using union with string values."""
79+
qs_a = Tag.objects.filter(group_id=self.group_id_str_1)
80+
qs_b = Tag.objects.filter(group_id=self.group_id_str_2)
81+
union_qs = qs_a.union(qs_b).order_by("name")
82+
self.assertSequenceEqual(union_qs, [self.t3, self.t4])
83+
84+
def test_filter_group_id_union_with_obj(self):
85+
"""Combine queries using union with ObjectId values."""
86+
qs_a = Tag.objects.filter(group_id=self.group_id_obj_1)
87+
qs_b = Tag.objects.filter(group_id=self.group_id_obj_2)
88+
union_qs = qs_a.union(qs_b).order_by("name")
89+
self.assertSequenceEqual(union_qs, [self.t3, self.t4])
90+
91+
def test_filter_invalid_object_id(self):
92+
value = "value1"
93+
msg = f"Field 'group_id' expected an ObjectId but got '{value}'."
94+
with self.assertRaisesMessage(ValueError, msg):
95+
Tag.objects.filter(group_id=value)
96+
97+
def test_values_in_subquery(self):
98+
# If a values() queryset is used, then the given values will be used
99+
# instead of forcing use of the relation's field.
100+
o1 = Order.objects.create()
101+
o2 = Order.objects.create()
102+
oi1 = OrderItem.objects.create(order=o1, status=None)
103+
oi1.status = oi1.pk
104+
oi1.save()
105+
OrderItem.objects.create(order=o2, status=None)
106+
# The query below should match o1 as it has related order_item with
107+
# id == status.
108+
self.assertSequenceEqual(
109+
Order.objects.filter(items__in=OrderItem.objects.values_list("status")),
110+
[o1],
111+
)

0 commit comments

Comments
 (0)