Skip to content

Commit ba6147c

Browse files
committed
add ListField and EmbeddedModelField tests
1 parent 5a247e8 commit ba6147c

File tree

4 files changed

+539
-0
lines changed

4 files changed

+539
-0
lines changed

tests/mongo_fields/__init__.py

Whitespace-only changes.

tests/mongo_fields/models.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from django_mongodb.fields import EmbeddedModelField, ListField
2+
3+
from django.db import models
4+
5+
6+
def count_calls(func):
7+
8+
def wrapper(*args, **kwargs):
9+
wrapper.calls += 1
10+
return func(*args, **kwargs)
11+
12+
wrapper.calls = 0
13+
14+
return wrapper
15+
16+
17+
class ReferenceList(models.Model):
18+
keys = ListField(models.ForeignKey("Model", models.CASCADE))
19+
20+
21+
class Model(models.Model):
22+
pass
23+
24+
25+
class Target(models.Model):
26+
index = models.IntegerField()
27+
28+
29+
class DecimalModel(models.Model):
30+
decimal = models.DecimalField(max_digits=9, decimal_places=2)
31+
32+
33+
class DecimalKey(models.Model):
34+
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True)
35+
36+
37+
class DecimalParent(models.Model):
38+
child = models.ForeignKey(DecimalKey, models.CASCADE)
39+
40+
41+
class DecimalsList(models.Model):
42+
decimals = ListField(models.ForeignKey(DecimalKey, models.CASCADE))
43+
44+
45+
class OrderedListModel(models.Model):
46+
ordered_ints = ListField(
47+
models.IntegerField(max_length=500),
48+
default=[],
49+
ordering=count_calls(lambda x: x),
50+
null=True,
51+
)
52+
ordered_nullable = ListField(ordering=lambda x: x, null=True)
53+
54+
55+
class ListModel(models.Model):
56+
integer = models.IntegerField(primary_key=True)
57+
floating_point = models.FloatField()
58+
names = ListField(models.CharField)
59+
names_with_default = ListField(models.CharField(max_length=500), default=[])
60+
names_nullable = ListField(models.CharField(max_length=500), null=True)
61+
62+
63+
class EmbeddedModelFieldModel(models.Model):
64+
simple = EmbeddedModelField("EmbeddedModel", null=True)
65+
simple_untyped = EmbeddedModelField(null=True)
66+
decimal_parent = EmbeddedModelField(DecimalParent, null=True)
67+
# typed_list = ListField(EmbeddedModelField('SetModel'))
68+
typed_list2 = ListField(EmbeddedModelField("EmbeddedModel"))
69+
untyped_list = ListField(EmbeddedModelField())
70+
# untyped_dict = DictField(EmbeddedModelField())
71+
ordered_list = ListField(EmbeddedModelField(), ordering=lambda obj: obj.index)
72+
73+
74+
class EmbeddedModel(models.Model):
75+
some_relation = models.ForeignKey(Target, models.CASCADE, null=True)
76+
someint = models.IntegerField(db_column="custom")
77+
auto_now = models.DateTimeField(auto_now=True)
78+
auto_now_add = models.DateTimeField(auto_now_add=True)
79+
80+
81+
class Child(models.Model):
82+
pass
83+
84+
85+
class Parent(models.Model):
86+
id = models.IntegerField(primary_key=True)
87+
integer_list = ListField(models.IntegerField)
88+
89+
# integer_dict = DictField(models.IntegerField)
90+
embedded_list = ListField(EmbeddedModelField(Child))
91+
92+
93+
# embedded_dict = DictField(EmbeddedModelField(Child))
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import time
2+
from decimal import Decimal
3+
4+
from django.db import models
5+
from django.test import TestCase
6+
7+
from .models import (
8+
Child,
9+
DecimalKey,
10+
DecimalParent,
11+
EmbeddedModel,
12+
EmbeddedModelFieldModel,
13+
OrderedListModel,
14+
Parent,
15+
Target,
16+
)
17+
18+
19+
class EmbeddedModelFieldTests(TestCase):
20+
21+
def assertEqualDatetime(self, d1, d2):
22+
"""Compares d1 and d2, ignoring microseconds."""
23+
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
24+
25+
def assertNotEqualDatetime(self, d1, d2):
26+
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
27+
28+
def _simple_instance(self):
29+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
30+
return EmbeddedModelFieldModel.objects.get()
31+
32+
def test_simple(self):
33+
instance = self._simple_instance()
34+
self.assertIsInstance(instance.simple, EmbeddedModel)
35+
# Make sure get_prep_value is called.
36+
self.assertEqual(instance.simple.someint, 5)
37+
# Primary keys should not be populated...
38+
self.assertEqual(instance.simple.id, None)
39+
# ... unless set explicitly.
40+
instance.simple.id = instance.id
41+
instance.save()
42+
instance = EmbeddedModelFieldModel.objects.get()
43+
self.assertEqual(instance.simple.id, instance.id)
44+
45+
def _test_pre_save(self, instance, get_field):
46+
# Make sure field.pre_save is called for embedded objects.
47+
48+
instance.save()
49+
auto_now = get_field(instance).auto_now
50+
auto_now_add = get_field(instance).auto_now_add
51+
self.assertNotEqual(auto_now, None)
52+
self.assertNotEqual(auto_now_add, None)
53+
54+
time.sleep(1) # FIXME
55+
instance.save()
56+
self.assertNotEqualDatetime(
57+
get_field(instance).auto_now, get_field(instance).auto_now_add
58+
)
59+
60+
instance = EmbeddedModelFieldModel.objects.get()
61+
instance.save()
62+
# auto_now_add shouldn't have changed now, but auto_now should.
63+
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add)
64+
self.assertGreater(get_field(instance).auto_now, auto_now)
65+
66+
def test_pre_save(self):
67+
obj = EmbeddedModelFieldModel(simple=EmbeddedModel())
68+
self._test_pre_save(obj, lambda instance: instance.simple)
69+
70+
def test_pre_save_untyped(self):
71+
obj = EmbeddedModelFieldModel(simple_untyped=EmbeddedModel())
72+
self._test_pre_save(obj, lambda instance: instance.simple_untyped)
73+
74+
def test_pre_save_in_list(self):
75+
obj = EmbeddedModelFieldModel(untyped_list=[EmbeddedModel()])
76+
self._test_pre_save(obj, lambda instance: instance.untyped_list[0])
77+
78+
def _test_pre_save_in_dict(self):
79+
obj = EmbeddedModelFieldModel(untyped_dict={"a": EmbeddedModel()})
80+
self._test_pre_save(obj, lambda instance: instance.untyped_dict["a"])
81+
82+
def test_pre_save_list(self):
83+
# Also make sure auto_now{,add} works for embedded object *lists*.
84+
EmbeddedModelFieldModel.objects.create(typed_list2=[EmbeddedModel()])
85+
instance = EmbeddedModelFieldModel.objects.get()
86+
87+
auto_now = instance.typed_list2[0].auto_now
88+
auto_now_add = instance.typed_list2[0].auto_now_add
89+
self.assertNotEqual(auto_now, None)
90+
self.assertNotEqual(auto_now_add, None)
91+
92+
instance.typed_list2.append(EmbeddedModel())
93+
instance.save()
94+
instance = EmbeddedModelFieldModel.objects.get()
95+
96+
self.assertEqualDatetime(instance.typed_list2[0].auto_now_add, auto_now_add)
97+
self.assertGreater(instance.typed_list2[0].auto_now, auto_now)
98+
self.assertNotEqual(instance.typed_list2[1].auto_now, None)
99+
self.assertNotEqual(instance.typed_list2[1].auto_now_add, None)
100+
101+
def test_error_messages(self):
102+
for kwargs, expected in (
103+
({"simple": 42}, EmbeddedModel),
104+
({"simple_untyped": 42}, models.Model),
105+
# ({"typed_list": [EmbeddedModel()]},), # SetModel),
106+
):
107+
self.assertRaisesMessage(
108+
TypeError,
109+
"Expected instance of type %r" % expected,
110+
EmbeddedModelFieldModel(**kwargs).save,
111+
)
112+
113+
def test_typed_listfield(self):
114+
EmbeddedModelFieldModel.objects.create(
115+
# typed_list=[SetModel(setfield=range(3)), SetModel(setfield=range(9))],
116+
ordered_list=[Target(index=i) for i in range(5, 0, -1)],
117+
)
118+
obj = EmbeddedModelFieldModel.objects.get()
119+
# self.assertIn(5, obj.typed_list[1].setfield)
120+
self.assertEqual([target.index for target in obj.ordered_list], range(1, 6))
121+
122+
def test_untyped_listfield(self):
123+
EmbeddedModelFieldModel.objects.create(
124+
untyped_list=[
125+
EmbeddedModel(someint=7),
126+
OrderedListModel(ordered_ints=list(range(5, 0, -1))),
127+
# SetModel(setfield=[1, 2, 2, 3]),
128+
]
129+
)
130+
instances = EmbeddedModelFieldModel.objects.get().untyped_list
131+
for instance, cls in zip(
132+
instances, [EmbeddedModel, OrderedListModel] # SetModel]
133+
):
134+
self.assertIsInstance(instance, cls)
135+
self.assertNotEqual(instances[0].auto_now, None)
136+
self.assertEqual(instances[1].ordered_ints, range(1, 6))
137+
138+
def _test_untyped_dict(self):
139+
EmbeddedModelFieldModel.objects.create(
140+
untyped_dict={
141+
# "a": SetModel(setfield=range(3)),
142+
# "b": DictModel(dictfield={"a": 1, "b": 2}),
143+
# "c": DictModel(dictfield={}, auto_now={"y": 1}),
144+
}
145+
)
146+
# data = EmbeddedModelFieldModel.objects.get().untyped_dict
147+
# self.assertIsInstance(data["a"], SetModel)
148+
# self.assertNotEqual(data["c"].auto_now["y"], None)
149+
150+
def test_foreignkey_in_embedded_object(self):
151+
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
152+
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
153+
simple = EmbeddedModelFieldModel.objects.get().simple
154+
self.assertNotIn("some_relation", simple.__dict__)
155+
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
156+
self.assertIsInstance(simple.some_relation, Target)
157+
158+
def test_embedded_field_with_foreign_conversion(self):
159+
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
160+
decimal_parent = DecimalParent.objects.create(child=decimal)
161+
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)
162+
163+
def test_update(self):
164+
"""
165+
Test that update can be used on an a subset of objects
166+
containing collections of embedded instances; see issue #13.
167+
Also ensure that updated values are coerced according to
168+
collection field.
169+
"""
170+
child1 = Child.objects.create()
171+
child2 = Child.objects.create()
172+
parent = Parent.objects.create(
173+
pk=1,
174+
integer_list=[1],
175+
# integer_dict={"a": 2},
176+
embedded_list=[child1],
177+
# embedded_dict={"a": child2},
178+
)
179+
Parent.objects.filter(pk=1).update(
180+
integer_list=["3"],
181+
# integer_dict={"b": "3"},
182+
embedded_list=[child2],
183+
# embedded_dict={"b": child1},
184+
)
185+
parent = Parent.objects.get()
186+
self.assertEqual(parent.integer_list, [3])
187+
# self.assertEqual(parent.integer_dict, {"b": 3})
188+
self.assertEqual(parent.embedded_list, [child2])
189+
# self.assertEqual(parent.embedded_dict, {"b": child1})

0 commit comments

Comments
 (0)