Skip to content

Commit 8cace6c

Browse files
committed
add support for EmbeddedModelField
1 parent ca8ac6a commit 8cace6c

File tree

5 files changed

+326
-2
lines changed

5 files changed

+326
-2
lines changed

django_mongodb/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def execute_sql(self, result_type):
737737
elif hasattr(value, "prepare_database_save"):
738738
if field.remote_field:
739739
value = value.prepare_database_save(field)
740-
else:
740+
elif not hasattr(field, "embedded_model"):
741741
raise TypeError(
742742
f"Tried to update field {field} with a model "
743743
f"instance, {value!r}. Use a value compatible with "

django_mongodb/fields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .auto import ObjectIdAutoField
22
from .duration import register_duration_field
3+
from .embedded_model import EmbeddedModelField
34
from .json import register_json_field
45

5-
__all__ = ["register_fields", "ObjectIdAutoField"]
6+
__all__ = ["register_fields", "EmbeddedModelField", "ObjectIdAutoField"]
67

78

89
def register_fields():
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from importlib import import_module
2+
3+
from django.db import IntegrityError, models
4+
from django.db.models.fields.related import lazy_related_operation
5+
6+
7+
class EmbeddedModelField(models.Field):
8+
"""Field that stores a model instance."""
9+
10+
def __init__(self, embedded_model=None, *args, **kwargs):
11+
"""
12+
`embedded_model` is the model class of the instance that will be
13+
stored. Like other relational fields, it may also be passed as a
14+
string.
15+
"""
16+
self.embedded_model = embedded_model
17+
super().__init__(*args, **kwargs)
18+
19+
def deconstruct(self):
20+
name, path, args, kwargs = super().deconstruct()
21+
if path.startswith("django_mongodb.fields.embedded_model"):
22+
path = path.replace("django_mongodb.fields.embedded_model", "django_mongodb.fields")
23+
if self.embedded_model:
24+
kwargs["embedded_model"] = self.embedded_model
25+
return name, path, args, kwargs
26+
27+
def get_internal_type(self):
28+
return "EmbeddedModelField"
29+
30+
def _set_model(self, model):
31+
"""
32+
Resolve embedded model class once the field knows the model it belongs
33+
to.
34+
35+
If the model argument passed to __init__() was a string, resolve that
36+
string to the corresponding model class, similar to relation fields.
37+
However, we need to know our own model to generate a valid key
38+
for the embedded model class lookup and EmbeddedModelFields are
39+
not contributed_to_class if used in iterable fields. Thus the
40+
collection field sets this field's "model" attribute in its
41+
contribute_to_class().
42+
"""
43+
self._model = model
44+
if model is not None and isinstance(self.embedded_model, str):
45+
46+
def _resolve_lookup(_, resolved_model):
47+
self.embedded_model = resolved_model
48+
49+
lazy_related_operation(_resolve_lookup, model, self.embedded_model)
50+
51+
model = property(lambda self: self._model, _set_model)
52+
53+
def stored_model(self, column_values):
54+
"""
55+
Return the fixed embedded_model this field was initialized
56+
with (typed embedding) or tries to determine the model from
57+
_module / _model keys stored together with column_values
58+
(untyped embedding).
59+
60+
Give precedence to the field's definition model, as silently using a
61+
differing serialized one could hide some data integrity problems.
62+
63+
Note that a single untyped EmbeddedModelField may process
64+
instances of different models (especially when used as a type
65+
of a collection field).
66+
"""
67+
module = column_values.pop("_module", None)
68+
model = column_values.pop("_model", None)
69+
if self.embedded_model is not None:
70+
return self.embedded_model
71+
if module is not None:
72+
return getattr(import_module(module), model)
73+
raise IntegrityError(
74+
"Untyped EmbeddedModelField trying to load data without serialized model class info."
75+
)
76+
77+
def from_db_value(self, value, expression, connection):
78+
return self.to_python(value)
79+
80+
def to_python(self, value):
81+
"""
82+
Passes embedded model fields' values through embedded fields
83+
to_python methods and reinstiatates the embedded instance.
84+
85+
We expect to receive a field.attname => value dict together
86+
with a model class from back-end database deconversion (which
87+
needs to know fields of the model beforehand).
88+
"""
89+
# Either the model class has already been determined during
90+
# deconverting values from the database or we've got a dict
91+
# from a deserializer that may contain model class info.
92+
if isinstance(value, tuple):
93+
embedded_model, attribute_values = value
94+
elif isinstance(value, dict):
95+
embedded_model = self.stored_model(value)
96+
attribute_values = value
97+
else:
98+
return value
99+
# Create the model instance.
100+
instance = embedded_model(
101+
**{
102+
# Pass values through respective fields' to_python(), leaving
103+
# fields for which no value is specified uninitialized.
104+
field.attname: field.to_python(attribute_values[field.attname])
105+
for field in embedded_model._meta.fields
106+
if field.attname in attribute_values
107+
}
108+
)
109+
instance._state.adding = False
110+
return instance
111+
112+
def get_db_prep_save(self, embedded_instance, connection):
113+
"""
114+
Apply pre_save() and get_db_prep_save() of embedded instance
115+
fields and passes a field => value mapping down to database
116+
type conversions.
117+
118+
The embedded instance will be saved as a column => value dict
119+
in the end (possibly augmented with info about instance's model
120+
for untyped embedding), but because we need to apply database
121+
type conversions on embedded instance fields' values and for
122+
these we need to know fields those values come from, we need to
123+
entrust the database layer with creating the dict.
124+
"""
125+
if embedded_instance is None:
126+
return None
127+
# The field's value should be an instance of the model given in
128+
# its declaration or at least of some model.
129+
embedded_model = self.embedded_model or models.Model
130+
if not isinstance(embedded_instance, embedded_model):
131+
raise TypeError(
132+
f"Expected instance of type {embedded_model!r}, not {type(embedded_instance)!r}."
133+
)
134+
# Apply pre_save() and get_db_prep_save() of embedded instance
135+
# fields, create the field => value mapping to be passed to
136+
# storage preprocessing.
137+
field_values = {}
138+
add = embedded_instance._state.adding
139+
for field in embedded_instance._meta.fields:
140+
value = field.get_db_prep_save(
141+
field.pre_save(embedded_instance, add), connection=connection
142+
)
143+
# Exclude unset primary keys (e.g. {'id': None}).
144+
if field.primary_key and value is None:
145+
continue
146+
field_values[field.attname] = value
147+
if self.embedded_model is None:
148+
# Untyped fields must store model info alongside values.
149+
field_values.update(
150+
(
151+
("_module", embedded_instance.__class__.__module__),
152+
("_model", embedded_instance.__class__.__name__),
153+
)
154+
)
155+
# This instance will exist in the database soon.
156+
# TODO.XXX: Ensure that this doesn't cause race conditions.
157+
embedded_instance._state.adding = False
158+
return field_values
159+
160+
def validate(self, value, model_instance):
161+
super().validate(value, model_instance)
162+
if self.embedded_model is None:
163+
return
164+
for field in self.embedded_model._meta.fields:
165+
attname = field.attname
166+
field.validate(getattr(value, attname), model_instance)

tests/model_fields_/models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from django.db import models
2+
3+
from django_mongodb.fields import EmbeddedModelField
4+
5+
6+
class Target(models.Model):
7+
index = models.IntegerField()
8+
9+
10+
class DecimalModel(models.Model):
11+
decimal = models.DecimalField(max_digits=9, decimal_places=2)
12+
13+
14+
class DecimalKey(models.Model):
15+
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True)
16+
17+
18+
class DecimalParent(models.Model):
19+
child = models.ForeignKey(DecimalKey, models.CASCADE)
20+
21+
22+
class EmbeddedModelFieldModel(models.Model):
23+
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
24+
untyped = EmbeddedModelField(null=True, blank=True)
25+
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)
26+
27+
28+
class EmbeddedModel(models.Model):
29+
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True)
30+
someint = models.IntegerField(db_column="custom_column")
31+
auto_now = models.DateTimeField(auto_now=True)
32+
auto_now_add = models.DateTimeField(auto_now_add=True)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import time
2+
from decimal import Decimal
3+
4+
from django.core.exceptions import ValidationError
5+
from django.db import models
6+
from django.test import SimpleTestCase, TestCase
7+
8+
from django_mongodb.fields import EmbeddedModelField
9+
10+
from .models import (
11+
DecimalKey,
12+
DecimalParent,
13+
EmbeddedModel,
14+
EmbeddedModelFieldModel,
15+
Target,
16+
)
17+
18+
19+
class MethodTests(SimpleTestCase):
20+
def test_deconstruct(self):
21+
field = EmbeddedModelField()
22+
name, path, args, kwargs = field.deconstruct()
23+
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField")
24+
self.assertEqual(args, [])
25+
self.assertEqual(kwargs, {})
26+
27+
def test_deconstruct_with_model(self):
28+
field = EmbeddedModelField("EmbeddedModel", null=True)
29+
name, path, args, kwargs = field.deconstruct()
30+
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField")
31+
self.assertEqual(args, [])
32+
self.assertEqual(kwargs, {"embedded_model": "EmbeddedModel", "null": True})
33+
34+
def test_validate(self):
35+
instance = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None))
36+
# This isn't quite right because "someint" is the field that's non-null.
37+
msg = "{'simple': ['This field cannot be null.']}"
38+
with self.assertRaisesMessage(ValidationError, msg):
39+
instance.full_clean()
40+
41+
42+
class QueryingTests(TestCase):
43+
def assertEqualDatetime(self, d1, d2):
44+
"""Compares d1 and d2, ignoring microseconds."""
45+
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
46+
47+
def assertNotEqualDatetime(self, d1, d2):
48+
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
49+
50+
def test_save_load(self):
51+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
52+
instance = EmbeddedModelFieldModel.objects.get()
53+
self.assertIsInstance(instance.simple, EmbeddedModel)
54+
# Make sure get_prep_value is called.
55+
self.assertEqual(instance.simple.someint, 5)
56+
# Primary keys should not be populated...
57+
self.assertEqual(instance.simple.id, None)
58+
# ... unless set explicitly.
59+
instance.simple.id = instance.id
60+
instance.save()
61+
instance = EmbeddedModelFieldModel.objects.get()
62+
self.assertEqual(instance.simple.id, instance.id)
63+
64+
def test_save_load_untyped(self):
65+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
66+
instance = EmbeddedModelFieldModel.objects.get()
67+
self.assertIsInstance(instance.simple, EmbeddedModel)
68+
# Make sure get_prep_value is called.
69+
self.assertEqual(instance.simple.someint, 5)
70+
# Primary keys should not be populated...
71+
self.assertEqual(instance.simple.id, None)
72+
# ... unless set explicitly.
73+
instance.simple.id = instance.id
74+
instance.save()
75+
instance = EmbeddedModelFieldModel.objects.get()
76+
self.assertEqual(instance.simple.id, instance.id)
77+
78+
def _test_pre_save(self, instance, get_field):
79+
# Field.pre_save() is called on embedded model fields.
80+
81+
instance.save()
82+
auto_now = get_field(instance).auto_now
83+
auto_now_add = get_field(instance).auto_now_add
84+
self.assertNotEqual(auto_now, None)
85+
self.assertNotEqual(auto_now_add, None)
86+
87+
time.sleep(1) # FIXME
88+
instance.save()
89+
self.assertNotEqualDatetime(get_field(instance).auto_now, get_field(instance).auto_now_add)
90+
91+
instance = EmbeddedModelFieldModel.objects.get()
92+
instance.save()
93+
# auto_now_add shouldn't have changed now, but auto_now should.
94+
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add)
95+
self.assertGreater(get_field(instance).auto_now, auto_now)
96+
97+
def test_pre_save(self):
98+
obj = EmbeddedModelFieldModel(simple=EmbeddedModel())
99+
self._test_pre_save(obj, lambda instance: instance.simple)
100+
101+
def test_pre_save_untyped(self):
102+
obj = EmbeddedModelFieldModel(untyped=EmbeddedModel())
103+
self._test_pre_save(obj, lambda instance: instance.untyped)
104+
105+
def test_error_messages(self):
106+
for model_kwargs, expected in (
107+
({"simple": 42}, EmbeddedModel),
108+
({"untyped": 42}, models.Model),
109+
):
110+
msg = "Expected instance of type %r" % expected
111+
with self.assertRaisesMessage(TypeError, msg):
112+
EmbeddedModelFieldModel(**model_kwargs).save()
113+
114+
def test_foreign_key_in_embedded_object(self):
115+
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
116+
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
117+
simple = EmbeddedModelFieldModel.objects.get().simple
118+
self.assertNotIn("some_relation", simple.__dict__)
119+
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
120+
self.assertIsInstance(simple.some_relation, Target)
121+
122+
def test_embedded_field_with_foreign_conversion(self):
123+
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
124+
decimal_parent = DecimalParent.objects.create(child=decimal)
125+
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)

0 commit comments

Comments
 (0)