Skip to content

Commit 7688a40

Browse files
committed
remove generic support
1 parent 8cace6c commit 7688a40

File tree

3 files changed

+51
-126
lines changed

3 files changed

+51
-126
lines changed

django_mongodb/fields/embedded_model.py

Lines changed: 18 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from importlib import import_module
2-
3-
from django.db import IntegrityError, models
1+
from django.db import models
42
from django.db.models.fields.related import lazy_related_operation
53

64

75
class EmbeddedModelField(models.Field):
86
"""Field that stores a model instance."""
97

10-
def __init__(self, embedded_model=None, *args, **kwargs):
8+
def __init__(self, embedded_model, *args, **kwargs):
119
"""
1210
`embedded_model` is the model class of the instance that will be
1311
stored. Like other relational fields, it may also be passed as a
@@ -20,8 +18,7 @@ def deconstruct(self):
2018
name, path, args, kwargs = super().deconstruct()
2119
if path.startswith("django_mongodb.fields.embedded_model"):
2220
path = path.replace("django_mongodb.fields.embedded_model", "django_mongodb.fields")
23-
if self.embedded_model:
24-
kwargs["embedded_model"] = self.embedded_model
21+
kwargs["embedded_model"] = self.embedded_model
2522
return name, path, args, kwargs
2623

2724
def get_internal_type(self):
@@ -50,60 +47,26 @@ def _resolve_lookup(_, resolved_model):
5047

5148
model = property(lambda self: self._model, _set_model)
5249

53-
def stored_model(self, column_values):
54-
"""
55-
Return the fixed embedded_model this field was initialized
56-
with (typed embedding) or tries to determine the model from
57-
_module / _model keys stored together with column_values
58-
(untyped embedding).
59-
60-
Give precedence to the field's definition model, as silently using a
61-
differing serialized one could hide some data integrity problems.
62-
63-
Note that a single untyped EmbeddedModelField may process
64-
instances of different models (especially when used as a type
65-
of a collection field).
66-
"""
67-
module = column_values.pop("_module", None)
68-
model = column_values.pop("_model", None)
69-
if self.embedded_model is not None:
70-
return self.embedded_model
71-
if module is not None:
72-
return getattr(import_module(module), model)
73-
raise IntegrityError(
74-
"Untyped EmbeddedModelField trying to load data without serialized model class info."
75-
)
76-
7750
def from_db_value(self, value, expression, connection):
7851
return self.to_python(value)
7952

8053
def to_python(self, value):
8154
"""
8255
Passes embedded model fields' values through embedded fields
83-
to_python methods and reinstiatates the embedded instance.
84-
85-
We expect to receive a field.attname => value dict together
86-
with a model class from back-end database deconversion (which
87-
needs to know fields of the model beforehand).
56+
to_python() and reinstiatates the embedded instance.
8857
"""
89-
# Either the model class has already been determined during
90-
# deconverting values from the database or we've got a dict
91-
# from a deserializer that may contain model class info.
92-
if isinstance(value, tuple):
93-
embedded_model, attribute_values = value
94-
elif isinstance(value, dict):
95-
embedded_model = self.stored_model(value)
96-
attribute_values = value
97-
else:
58+
if value is None:
59+
return None
60+
if not isinstance(value, dict):
9861
return value
9962
# Create the model instance.
100-
instance = embedded_model(
63+
instance = self.embedded_model(
10164
**{
10265
# Pass values through respective fields' to_python(), leaving
10366
# fields for which no value is specified uninitialized.
104-
field.attname: field.to_python(attribute_values[field.attname])
105-
for field in embedded_model._meta.fields
106-
if field.attname in attribute_values
67+
field.attname: field.to_python(value[field.attname])
68+
for field in self.embedded_model._meta.fields
69+
if field.attname in value
10770
}
10871
)
10972
instance._state.adding = False
@@ -115,21 +78,17 @@ def get_db_prep_save(self, embedded_instance, connection):
11578
fields and passes a field => value mapping down to database
11679
type conversions.
11780
118-
The embedded instance will be saved as a column => value dict
119-
in the end (possibly augmented with info about instance's model
120-
for untyped embedding), but because we need to apply database
121-
type conversions on embedded instance fields' values and for
122-
these we need to know fields those values come from, we need to
123-
entrust the database layer with creating the dict.
81+
The embedded instance will be saved as a column => value dict, but
82+
because we need to apply database type conversions on embedded instance
83+
fields' values and for these we need to know fields those values come
84+
from, we need to entrust the database layer with creating the dict.
12485
"""
12586
if embedded_instance is None:
12687
return None
127-
# The field's value should be an instance of the model given in
128-
# its declaration or at least of some model.
129-
embedded_model = self.embedded_model or models.Model
130-
if not isinstance(embedded_instance, embedded_model):
88+
if not isinstance(embedded_instance, self.embedded_model):
13189
raise TypeError(
132-
f"Expected instance of type {embedded_model!r}, not {type(embedded_instance)!r}."
90+
f"Expected instance of type {self.embedded_model!r}, not "
91+
f"{type(embedded_instance)!r}."
13392
)
13493
# Apply pre_save() and get_db_prep_save() of embedded instance
13594
# fields, create the field => value mapping to be passed to
@@ -144,14 +103,6 @@ def get_db_prep_save(self, embedded_instance, connection):
144103
if field.primary_key and value is None:
145104
continue
146105
field_values[field.attname] = value
147-
if self.embedded_model is None:
148-
# Untyped fields must store model info alongside values.
149-
field_values.update(
150-
(
151-
("_module", embedded_instance.__class__.__module__),
152-
("_model", embedded_instance.__class__.__name__),
153-
)
154-
)
155106
# This instance will exist in the database soon.
156107
# TODO.XXX: Ensure that this doesn't cause race conditions.
157108
embedded_instance._state.adding = False

tests/model_fields_/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class DecimalParent(models.Model):
2121

2222
class EmbeddedModelFieldModel(models.Model):
2323
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
24-
untyped = EmbeddedModelField(null=True, blank=True)
2524
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)
2625

2726

tests/model_fields_/test_embedded_model.py

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from decimal import Decimal
33

44
from django.core.exceptions import ValidationError
5-
from django.db import models
65
from django.test import SimpleTestCase, TestCase
76

87
from django_mongodb.fields import EmbeddedModelField
@@ -18,98 +17,74 @@
1817

1918
class MethodTests(SimpleTestCase):
2019
def test_deconstruct(self):
21-
field = EmbeddedModelField()
22-
name, path, args, kwargs = field.deconstruct()
23-
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField")
24-
self.assertEqual(args, [])
25-
self.assertEqual(kwargs, {})
26-
27-
def test_deconstruct_with_model(self):
2820
field = EmbeddedModelField("EmbeddedModel", null=True)
2921
name, path, args, kwargs = field.deconstruct()
3022
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField")
3123
self.assertEqual(args, [])
3224
self.assertEqual(kwargs, {"embedded_model": "EmbeddedModel", "null": True})
3325

3426
def test_validate(self):
35-
instance = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None))
27+
obj = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None))
3628
# This isn't quite right because "someint" is the field that's non-null.
3729
msg = "{'simple': ['This field cannot be null.']}"
3830
with self.assertRaisesMessage(ValidationError, msg):
39-
instance.full_clean()
31+
obj.full_clean()
4032

4133

4234
class QueryingTests(TestCase):
4335
def assertEqualDatetime(self, d1, d2):
44-
"""Compares d1 and d2, ignoring microseconds."""
36+
"""Compare d1 and d2, ignoring microseconds."""
4537
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
4638

4739
def assertNotEqualDatetime(self, d1, d2):
4840
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
4941

5042
def test_save_load(self):
5143
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
52-
instance = EmbeddedModelFieldModel.objects.get()
53-
self.assertIsInstance(instance.simple, EmbeddedModel)
44+
obj = EmbeddedModelFieldModel.objects.get()
45+
self.assertIsInstance(obj.simple, EmbeddedModel)
5446
# Make sure get_prep_value is called.
55-
self.assertEqual(instance.simple.someint, 5)
47+
self.assertEqual(obj.simple.someint, 5)
5648
# Primary keys should not be populated...
57-
self.assertEqual(instance.simple.id, None)
49+
self.assertEqual(obj.simple.id, None)
5850
# ... unless set explicitly.
59-
instance.simple.id = instance.id
60-
instance.save()
61-
instance = EmbeddedModelFieldModel.objects.get()
62-
self.assertEqual(instance.simple.id, instance.id)
51+
obj.simple.id = obj.id
52+
obj.save()
53+
obj = EmbeddedModelFieldModel.objects.get()
54+
self.assertEqual(obj.simple.id, obj.id)
6355

64-
def test_save_load_untyped(self):
65-
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5"))
66-
instance = EmbeddedModelFieldModel.objects.get()
67-
self.assertIsInstance(instance.simple, EmbeddedModel)
68-
# Make sure get_prep_value is called.
69-
self.assertEqual(instance.simple.someint, 5)
70-
# Primary keys should not be populated...
71-
self.assertEqual(instance.simple.id, None)
72-
# ... unless set explicitly.
73-
instance.simple.id = instance.id
74-
instance.save()
75-
instance = EmbeddedModelFieldModel.objects.get()
76-
self.assertEqual(instance.simple.id, instance.id)
56+
def test_save_load_null(self):
57+
EmbeddedModelFieldModel.objects.create(simple=None)
58+
obj = EmbeddedModelFieldModel.objects.get()
59+
self.assertIsNone(obj.simple)
7760

78-
def _test_pre_save(self, instance, get_field):
79-
# Field.pre_save() is called on embedded model fields.
61+
def test_pre_save(self):
62+
"""Field.pre_save() is called on embedded model fields."""
63+
obj = EmbeddedModelFieldModel(simple=EmbeddedModel())
8064

81-
instance.save()
82-
auto_now = get_field(instance).auto_now
83-
auto_now_add = get_field(instance).auto_now_add
65+
obj.save()
66+
auto_now = obj.simple.auto_now
67+
auto_now_add = obj.simple.auto_now_add
8468
self.assertNotEqual(auto_now, None)
8569
self.assertNotEqual(auto_now_add, None)
8670

8771
time.sleep(1) # FIXME
88-
instance.save()
89-
self.assertNotEqualDatetime(get_field(instance).auto_now, get_field(instance).auto_now_add)
72+
obj.save()
73+
self.assertNotEqualDatetime(obj.simple.auto_now, obj.simple.auto_now_add)
9074

91-
instance = EmbeddedModelFieldModel.objects.get()
92-
instance.save()
75+
obj = EmbeddedModelFieldModel.objects.get()
76+
obj.save()
9377
# auto_now_add shouldn't have changed now, but auto_now should.
94-
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add)
95-
self.assertGreater(get_field(instance).auto_now, auto_now)
96-
97-
def test_pre_save(self):
98-
obj = EmbeddedModelFieldModel(simple=EmbeddedModel())
99-
self._test_pre_save(obj, lambda instance: instance.simple)
100-
101-
def test_pre_save_untyped(self):
102-
obj = EmbeddedModelFieldModel(untyped=EmbeddedModel())
103-
self._test_pre_save(obj, lambda instance: instance.untyped)
78+
self.assertEqualDatetime(obj.simple.auto_now_add, auto_now_add)
79+
self.assertGreater(obj.simple.auto_now, auto_now)
10480

10581
def test_error_messages(self):
106-
for model_kwargs, expected in (
107-
({"simple": 42}, EmbeddedModel),
108-
({"untyped": 42}, models.Model),
109-
):
110-
msg = "Expected instance of type %r" % expected
111-
with self.assertRaisesMessage(TypeError, msg):
112-
EmbeddedModelFieldModel(**model_kwargs).save()
82+
msg = (
83+
"Expected instance of type <class 'model_fields_.models.EmbeddedModel'>, "
84+
"not <class 'int'>."
85+
)
86+
with self.assertRaisesMessage(TypeError, msg):
87+
EmbeddedModelFieldModel(simple=42).save()
11388

11489
def test_foreign_key_in_embedded_object(self):
11590
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))

0 commit comments

Comments
 (0)