Skip to content

Commit bc1033c

Browse files
committed
prohibit embedded relational fields
1 parent 530d1e2 commit bc1033c

File tree

3 files changed

+35
-28
lines changed

3 files changed

+35
-28
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.core import checks
12
from django.db import models
23
from django.db.models.fields.related import lazy_related_operation
34
from django.db.models.lookups import Transform
@@ -17,6 +18,19 @@ def __init__(self, embedded_model, *args, **kwargs):
1718
self.embedded_model = embedded_model
1819
super().__init__(*args, **kwargs)
1920

21+
def check(self, **kwargs):
22+
errors = super().check(**kwargs)
23+
for field in self.embedded_model._meta.fields:
24+
if field.remote_field:
25+
errors.append(
26+
checks.Error(
27+
"Embedded models cannot have relational fields.",
28+
obj=self,
29+
id="django_mongodb.embedded_model.E001",
30+
)
31+
)
32+
return errors
33+
2034
def deconstruct(self):
2135
name, path, args, kwargs = super().deconstruct()
2236
if path.startswith("django_mongodb_backend.fields.embedded_model"):

tests/model_fields_/models.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ class ArrayEnumModel(models.Model):
9494

9595

9696
# EmbeddedModelField
97-
class Target(models.Model):
98-
index = models.IntegerField()
99-
100-
10197
class DecimalModel(models.Model):
10298
decimal = models.DecimalField(max_digits=9, decimal_places=2)
10399

@@ -106,17 +102,12 @@ class DecimalKey(models.Model):
106102
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True)
107103

108104

109-
class DecimalParent(models.Model):
110-
child = models.ForeignKey(DecimalKey, models.CASCADE)
111-
112-
113105
class EmbeddedModelFieldModel(models.Model):
114106
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
115-
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)
107+
decimal_parent = EmbeddedModelField(DecimalKey, null=True, blank=True)
116108

117109

118110
class EmbeddedModel(models.Model):
119-
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True)
120111
someint = models.IntegerField(db_column="custom_column")
121112
auto_now = models.DateTimeField(auto_now=True)
122113
auto_now_add = models.DateTimeField(auto_now_add=True)

tests/model_fields_/test_embedded_model.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
from decimal import Decimal
2-
31
from django.core.exceptions import ValidationError
2+
from django.db import models
43
from django.test import SimpleTestCase, TestCase
4+
from django.test.utils import isolate_apps
55

66
from django_mongodb_backend.fields import EmbeddedModelField
77

88
from .models import (
99
Address,
1010
Author,
1111
Book,
12-
DecimalKey,
13-
DecimalParent,
1412
EmbeddedModel,
1513
EmbeddedModelFieldModel,
16-
Target,
1714
)
1815

1916

@@ -82,19 +79,6 @@ def test_pre_save(self):
8279
self.assertEqual(obj.simple.auto_now_add, auto_now_add)
8380
self.assertGreater(obj.simple.auto_now, auto_now_two)
8481

85-
def test_foreign_key_in_embedded_object(self):
86-
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
87-
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
88-
simple = EmbeddedModelFieldModel.objects.get().simple
89-
self.assertNotIn("some_relation", simple.__dict__)
90-
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
91-
self.assertIsInstance(simple.some_relation, Target)
92-
93-
def test_embedded_field_with_foreign_conversion(self):
94-
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
95-
decimal_parent = DecimalParent.objects.create(child=decimal)
96-
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)
97-
9882

9983
class QueryingTests(TestCase):
10084
@classmethod
@@ -134,3 +118,21 @@ def test_nested(self):
134118
author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
135119
)
136120
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])
121+
122+
123+
@isolate_apps("model_fields_")
124+
class CheckTests(SimpleTestCase):
125+
def test_no_relational_fields(self):
126+
class Target(models.Model):
127+
key = models.ForeignKey("MyModel", models.CASCADE)
128+
129+
class MyModel(models.Model):
130+
field = EmbeddedModelField(Target)
131+
132+
model = MyModel()
133+
errors = model.check()
134+
self.assertEqual(len(errors), 1)
135+
# The inner CharField has a non-positive max_length.
136+
self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001")
137+
msg = errors[0].msg
138+
self.assertEqual(msg, "Embedded models cannot have relational fields.")

0 commit comments

Comments
 (0)