Skip to content

Commit 2ccac6a

Browse files
committed
add EmbeddedModelField
1 parent 419b97e commit 2ccac6a

File tree

5 files changed

+335
-3
lines changed

5 files changed

+335
-3
lines changed

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def execute_sql(self, result_type):
741741
elif hasattr(value, "prepare_database_save"):
742742
if field.remote_field:
743743
value = value.prepare_database_save(field)
744-
else:
744+
elif not hasattr(field, "embedded_model"):
745745
raise TypeError(
746746
f"Tried to update field {field} with a model "
747747
f"instance, {value!r}. Use a value compatible with "

django_mongodb_backend/fields/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
from .array import ArrayField
22
from .auto import ObjectIdAutoField
33
from .duration import register_duration_field
4+
from .embedded_model import EmbeddedModelField
45
from .json import register_json_field
56
from .objectid import ObjectIdField
67

7-
__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"]
8+
__all__ = [
9+
"register_fields",
10+
"ArrayField",
11+
"EmbeddedModelField",
12+
"ObjectIdAutoField",
13+
"ObjectIdField",
14+
]
815

916

1017
def register_fields():
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from importlib import import_module
2+
3+
from django.db import IntegrityError, models
4+
from django.db.models.fields.related import lazy_related_operation
5+
6+
7+
class EmbeddedModelField(models.Field):
8+
"""Field that stores a model instance."""
9+
10+
def __init__(self, embedded_model=None, *args, **kwargs):
11+
"""
12+
`embedded_model` is the model class of the instance that will be
13+
stored. Like other relational fields, it may also be passed as a
14+
string.
15+
"""
16+
self.embedded_model = embedded_model
17+
super().__init__(*args, **kwargs)
18+
19+
def deconstruct(self):
20+
name, path, args, kwargs = super().deconstruct()
21+
if path.startswith("django_mongodb_backend.fields.embedded_model"):
22+
path = path.replace(
23+
"django_mongodb_backend.fields.embedded_model", "django_mongodb_backend.fields"
24+
)
25+
if self.embedded_model:
26+
kwargs["embedded_model"] = self.embedded_model
27+
return name, path, args, kwargs
28+
29+
def get_internal_type(self):
30+
return "EmbeddedModelField"
31+
32+
def _set_model(self, model):
33+
"""
34+
Resolve embedded model class once the field knows the model it belongs
35+
to.
36+
37+
If the model argument passed to __init__() was a string, resolve that
38+
string to the corresponding model class, similar to relation fields.
39+
However, we need to know our own model to generate a valid key
40+
for the embedded model class lookup and EmbeddedModelFields are
41+
not contributed_to_class if used in iterable fields. Thus the
42+
collection field sets this field's "model" attribute in its
43+
contribute_to_class().
44+
"""
45+
self._model = model
46+
if model is not None and isinstance(self.embedded_model, str):
47+
48+
def _resolve_lookup(_, resolved_model):
49+
self.embedded_model = resolved_model
50+
51+
lazy_related_operation(_resolve_lookup, model, self.embedded_model)
52+
53+
model = property(lambda self: self._model, _set_model)
54+
55+
def stored_model(self, column_values):
56+
"""
57+
Return the fixed embedded_model this field was initialized
58+
with (typed embedding) or tries to determine the model from
59+
_module / _model keys stored together with column_values
60+
(untyped embedding).
61+
62+
Give precedence to the field's definition model, as silently using a
63+
differing serialized one could hide some data integrity problems.
64+
65+
Note that a single untyped EmbeddedModelField may process
66+
instances of different models (especially when used as a type
67+
of a collection field).
68+
"""
69+
module = column_values.pop("_module", None)
70+
model = column_values.pop("_model", None)
71+
if self.embedded_model is not None:
72+
return self.embedded_model
73+
if module is not None:
74+
return getattr(import_module(module), model)
75+
raise IntegrityError(
76+
"Untyped EmbeddedModelField trying to load data without serialized model class info."
77+
)
78+
79+
def from_db_value(self, value, expression, connection):
80+
return self.to_python(value)
81+
82+
def to_python(self, value):
83+
"""
84+
Passes embedded model fields' values through embedded fields
85+
to_python methods and reinstiatates the embedded instance.
86+
87+
We expect to receive a field.attname => value dict together
88+
with a model class from back-end database deconversion (which
89+
needs to know fields of the model beforehand).
90+
"""
91+
# Either the model class has already been determined during
92+
# deconverting values from the database or we've got a dict
93+
# from a deserializer that may contain model class info.
94+
if isinstance(value, tuple):
95+
embedded_model, attribute_values = value
96+
elif isinstance(value, dict):
97+
embedded_model = self.stored_model(value)
98+
attribute_values = value
99+
else:
100+
return value
101+
# Create the model instance.
102+
instance = embedded_model(
103+
**{
104+
# Pass values through respective fields' to_python(), leaving
105+
# fields for which no value is specified uninitialized.
106+
field.attname: field.to_python(attribute_values[field.attname])
107+
for field in embedded_model._meta.fields
108+
if field.attname in attribute_values
109+
}
110+
)
111+
instance._state.adding = False
112+
return instance
113+
114+
def get_db_prep_save(self, embedded_instance, connection):
115+
"""
116+
Apply pre_save() and get_db_prep_save() of embedded instance
117+
fields and passes a field => value mapping down to database
118+
type conversions.
119+
120+
The embedded instance will be saved as a column => value dict
121+
in the end (possibly augmented with info about instance's model
122+
for untyped embedding), but because we need to apply database
123+
type conversions on embedded instance fields' values and for
124+
these we need to know fields those values come from, we need to
125+
entrust the database layer with creating the dict.
126+
"""
127+
if embedded_instance is None:
128+
return None
129+
# The field's value should be an instance of the model given in
130+
# its declaration or at least of some model.
131+
embedded_model = self.embedded_model or models.Model
132+
if not isinstance(embedded_instance, embedded_model):
133+
raise TypeError(
134+
f"Expected instance of type {embedded_model!r}, not {type(embedded_instance)!r}."
135+
)
136+
# Apply pre_save() and get_db_prep_save() of embedded instance
137+
# fields, create the field => value mapping to be passed to
138+
# storage preprocessing.
139+
field_values = {}
140+
add = embedded_instance._state.adding
141+
for field in embedded_instance._meta.fields:
142+
value = field.get_db_prep_save(
143+
field.pre_save(embedded_instance, add), connection=connection
144+
)
145+
# Exclude unset primary keys (e.g. {'id': None}).
146+
if field.primary_key and value is None:
147+
continue
148+
field_values[field.attname] = value
149+
if self.embedded_model is None:
150+
# Untyped fields must store model info alongside values.
151+
field_values.update(
152+
(
153+
("_module", embedded_instance.__class__.__module__),
154+
("_model", embedded_instance.__class__.__name__),
155+
)
156+
)
157+
# This instance will exist in the database soon.
158+
# TODO.XXX: Ensure that this doesn't cause race conditions.
159+
embedded_instance._state.adding = False
160+
return field_values
161+
162+
def validate(self, value, model_instance):
163+
super().validate(value, model_instance)
164+
if self.embedded_model is None:
165+
return
166+
for field in self.embedded_model._meta.fields:
167+
attname = field.attname
168+
field.validate(getattr(value, attname), model_instance)

tests/model_fields_/models.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from django.db import models
44

5-
from django_mongodb_backend.fields import ArrayField, ObjectIdField
5+
from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
66

77

8+
# ObjectIdField
89
class ObjectIdModel(models.Model):
910
field = ObjectIdField()
1011

@@ -17,6 +18,7 @@ class PrimaryKeyObjectIdModel(models.Model):
1718
field = ObjectIdField(primary_key=True)
1819

1920

21+
# ArrayField
2022
class ArrayFieldSubclass(ArrayField):
2123
def __init__(self, *args, **kwargs):
2224
super().__init__(models.IntegerField())
@@ -89,3 +91,33 @@ def get_prep_value(self, value):
8991

9092
class ArrayEnumModel(models.Model):
9193
array_of_enums = ArrayField(EnumField(max_length=20))
94+
95+
96+
# EmbeddedModelField
97+
class Target(models.Model):
98+
index = models.IntegerField()
99+
100+
101+
class DecimalModel(models.Model):
102+
decimal = models.DecimalField(max_digits=9, decimal_places=2)
103+
104+
105+
class DecimalKey(models.Model):
106+
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True)
107+
108+
109+
class DecimalParent(models.Model):
110+
child = models.ForeignKey(DecimalKey, models.CASCADE)
111+
112+
113+
class EmbeddedModelFieldModel(models.Model):
114+
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
115+
untyped = EmbeddedModelField(null=True, blank=True)
116+
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)
117+
118+
119+
class EmbeddedModel(models.Model):
120+
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True)
121+
someint = models.IntegerField(db_column="custom_column")
122+
auto_now = models.DateTimeField(auto_now=True)
123+
auto_now_add = models.DateTimeField(auto_now_add=True)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import time
2+
from decimal import Decimal
3+
4+
from django.core.exceptions import ValidationError
5+
from django.db import models
6+
from django.test import SimpleTestCase, TestCase
7+
8+
from django_mongodb_backend.fields import EmbeddedModelField
9+
10+
from .models import (
11+
DecimalKey,
12+
DecimalParent,
13+
EmbeddedModel,
14+
EmbeddedModelFieldModel,
15+
Target,
16+
)
17+
18+
19+
class MethodTests(SimpleTestCase):
20+
def test_deconstruct(self):
21+
field = EmbeddedModelField()
22+
name, path, args, kwargs = field.deconstruct()
23+
self.assertEqual(path, "django_mongodb_backend.fields.EmbeddedModelField")
24+
self.assertEqual(args, [])
25+
self.assertEqual(kwargs, {})
26+
27+
def test_deconstruct_with_model(self):
28+
field = EmbeddedModelField("EmbeddedModel", null=True)
29+
name, path, args, kwargs = field.deconstruct()
30+
self.assertEqual(path, "django_mongodb_backend.fields.EmbeddedModelField")
31+
self.assertEqual(args, [])
32+
self.assertEqual(kwargs, {"embedded_model": "EmbeddedModel", "null": True})
33+
34+
def test_validate(self):
35+
instance = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None))
36+
# This isn't quite right because "someint" is the field that's non-null.
37+
msg = "{'simple': ['This field cannot be null.']}"
38+
with self.assertRaisesMessage(ValidationError, msg):
39+
instance.full_clean()
40+
41+
42+
class QueryingTests(TestCase):
43+
def assertEqualDatetime(self, d1, d2):
44+
"""Compares d1 and d2, ignoring microseconds."""
45+
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
46+
47+
def assertNotEqualDatetime(self, d1, d2):
48+
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
49+
50+
def test_save_load(self):
51+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
52+
instance = EmbeddedModelFieldModel.objects.get()
53+
self.assertIsInstance(instance.simple, EmbeddedModel)
54+
# Make sure get_prep_value is called.
55+
self.assertEqual(instance.simple.someint, 5)
56+
# Primary keys should not be populated...
57+
self.assertEqual(instance.simple.id, None)
58+
# ... unless set explicitly.
59+
instance.simple.id = instance.id
60+
instance.save()
61+
instance = EmbeddedModelFieldModel.objects.get()
62+
self.assertEqual(instance.simple.id, instance.id)
63+
64+
def test_save_load_untyped(self):
65+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
66+
instance = EmbeddedModelFieldModel.objects.get()
67+
self.assertIsInstance(instance.simple, EmbeddedModel)
68+
# Make sure get_prep_value is called.
69+
self.assertEqual(instance.simple.someint, 5)
70+
# Primary keys should not be populated...
71+
self.assertEqual(instance.simple.id, None)
72+
# ... unless set explicitly.
73+
instance.simple.id = instance.id
74+
instance.save()
75+
instance = EmbeddedModelFieldModel.objects.get()
76+
self.assertEqual(instance.simple.id, instance.id)
77+
78+
def _test_pre_save(self, instance, get_field):
79+
# Field.pre_save() is called on embedded model fields.
80+
81+
instance.save()
82+
auto_now = get_field(instance).auto_now
83+
auto_now_add = get_field(instance).auto_now_add
84+
self.assertNotEqual(auto_now, None)
85+
self.assertNotEqual(auto_now_add, None)
86+
87+
time.sleep(1) # FIXME
88+
instance.save()
89+
self.assertNotEqualDatetime(get_field(instance).auto_now, get_field(instance).auto_now_add)
90+
91+
instance = EmbeddedModelFieldModel.objects.get()
92+
instance.save()
93+
# auto_now_add shouldn't have changed now, but auto_now should.
94+
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add)
95+
self.assertGreater(get_field(instance).auto_now, auto_now)
96+
97+
def test_pre_save(self):
98+
obj = EmbeddedModelFieldModel(simple=EmbeddedModel())
99+
self._test_pre_save(obj, lambda instance: instance.simple)
100+
101+
def test_pre_save_untyped(self):
102+
obj = EmbeddedModelFieldModel(untyped=EmbeddedModel())
103+
self._test_pre_save(obj, lambda instance: instance.untyped)
104+
105+
def test_error_messages(self):
106+
for model_kwargs, expected in (
107+
({"simple": 42}, EmbeddedModel),
108+
({"untyped": 42}, models.Model),
109+
):
110+
msg = "Expected instance of type %r" % expected
111+
with self.assertRaisesMessage(TypeError, msg):
112+
EmbeddedModelFieldModel(**model_kwargs).save()
113+
114+
def test_foreign_key_in_embedded_object(self):
115+
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
116+
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
117+
simple = EmbeddedModelFieldModel.objects.get().simple
118+
self.assertNotIn("some_relation", simple.__dict__)
119+
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
120+
self.assertIsInstance(simple.some_relation, Target)
121+
122+
def test_embedded_field_with_foreign_conversion(self):
123+
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
124+
decimal_parent = DecimalParent.objects.create(child=decimal)
125+
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)

0 commit comments

Comments
 (0)