Skip to content

Commit e2947b8

Browse files
committed
wip forms support
1 parent a33ee4e commit e2947b8

File tree

6 files changed

+234
-0
lines changed

6 files changed

+234
-0
lines changed

django_mongodb/fields/embedded_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from django.db.models.fields.related import lazy_related_operation
33
from django.db.models.lookups import Transform
44

5+
from .. import forms
6+
57

68
class EmbeddedModelField(models.Field):
79
"""Field that stores a model instance."""
@@ -123,6 +125,16 @@ def validate(self, value, model_instance):
123125
attname = field.attname
124126
field.validate(getattr(value, attname), model_instance)
125127

128+
def formfield(self, **kwargs):
129+
return super().formfield(
130+
**{
131+
"form_class": forms.EmbeddedModelFormField,
132+
"model": self.embedded_model,
133+
"name": self.name,
134+
**kwargs,
135+
}
136+
)
137+
126138

127139
class KeyTransform(Transform):
128140
def __init__(self, key_name, *args, **kwargs):

django_mongodb/forms.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from django import forms
2+
from django.forms.models import modelform_factory
3+
from django.utils.safestring import mark_safe
4+
from django.utils.translation import gettext_lazy as _
5+
6+
7+
class EmbeddedModelWidget(forms.MultiWidget):
8+
def __init__(self, field_names, *args, **kwargs):
9+
self.field_names = field_names
10+
super().__init__(*args, **kwargs)
11+
# The default widget names are "_0", "_1", etc. Use the field names
12+
# instead since that's how they'll be rendered by the model form.
13+
self.widgets_names = ["-" + name for name in field_names]
14+
15+
def decompress(self, value):
16+
if value is None:
17+
return []
18+
# Get the data from `value` (a model) for each field.
19+
return [getattr(value, name) for name in self.field_names]
20+
21+
22+
class EmbeddedModelBoundField(forms.BoundField):
23+
def __str__(self):
24+
"""Render the model form as the representation for this field."""
25+
form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs)
26+
return mark_safe(f"{form.as_div()}") # noqa: S308
27+
28+
29+
class EmbeddedModelFormField(forms.MultiValueField):
30+
default_error_messages = {
31+
"invalid": _("Enter a list of values."),
32+
"incomplete": _("Enter all required values."),
33+
}
34+
35+
def __init__(self, model, name, *args, **kwargs):
36+
form_kwargs = {}
37+
# The field must be prefixed with the name of the field.
38+
form_kwargs["prefix"] = name
39+
self.form_kwargs = form_kwargs
40+
self.model_form_cls = modelform_factory(model, fields="__all__")
41+
self.model_form = self.model_form_cls(**form_kwargs)
42+
self.field_names = list(self.model_form.fields.keys())
43+
fields = self.model_form.fields.values()
44+
widgets = [field.widget for field in fields]
45+
widget = EmbeddedModelWidget(self.field_names, widgets)
46+
super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs)
47+
48+
def compress(self, data_dict):
49+
if not data_dict:
50+
return None
51+
values = dict(zip(self.field_names, data_dict, strict=False))
52+
return self.model_form._meta.model(**values)
53+
54+
def get_bound_field(self, form, field_name):
55+
return EmbeddedModelBoundField(form, self, field_name)
56+
57+
def bound_data(self, data, initial):
58+
if self.disabled:
59+
return initial
60+
# The bound data must be transformed into a model instance.
61+
return self.compress(data)

tests/model_forms_/__init__.py

Whitespace-only changes.

tests/model_forms_/forms.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from django import forms
2+
3+
from .models import Author
4+
5+
6+
class AuthorForm(forms.ModelForm):
7+
class Meta:
8+
fields = "__all__"
9+
model = Author

tests/model_forms_/models.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from django.db import models
2+
3+
from django_mongodb.fields import EmbeddedModelField
4+
5+
6+
class Address(models.Model):
7+
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box")
8+
city = models.CharField(max_length=20)
9+
state = models.CharField(max_length=2)
10+
zip_code = models.IntegerField()
11+
12+
13+
class Author(models.Model):
14+
name = models.CharField(max_length=10)
15+
age = models.IntegerField()
16+
address = EmbeddedModelField(Address)
17+
billing_address = EmbeddedModelField(Address, blank=True, null=True)
18+
19+
20+
class Book(models.Model):
21+
name = models.CharField(max_length=100)
22+
author = EmbeddedModelField(Author)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from django.test import TestCase
2+
3+
from .forms import AuthorForm
4+
from .models import Address, Author
5+
6+
7+
class ModelFormTests(TestCase):
8+
def test_update(self):
9+
author = Author.objects.create(
10+
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
11+
)
12+
data = {
13+
"name": "Bob",
14+
"age": 51,
15+
"address-po_box": "",
16+
"address-city": "New York City",
17+
"address-state": "NY",
18+
"address-zip_code": "10001",
19+
}
20+
form = AuthorForm(data, instance=author)
21+
self.assertTrue(form.is_valid())
22+
form.save()
23+
author.refresh_from_db()
24+
self.assertEqual(author.age, 51)
25+
self.assertEqual(author.address.city, "New York City")
26+
27+
def test_some_missing_data(self):
28+
author = Author.objects.create(
29+
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
30+
)
31+
data = {
32+
"name": "Bob",
33+
"age": 51,
34+
"address-po_box": "",
35+
"address-city": "New York City",
36+
"address-state": "NY",
37+
"address-zip_code": "",
38+
}
39+
form = AuthorForm(data, instance=author)
40+
self.assertFalse(form.is_valid())
41+
self.assertEqual(form.errors["address"], ["Enter all required values."])
42+
43+
def test_invalid_field_data(self):
44+
"""A field's data (state) is too long."""
45+
author = Author.objects.create(
46+
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
47+
)
48+
data = {
49+
"name": "Bob",
50+
"age": 51,
51+
"address-po_box": "",
52+
"address-city": "New York City",
53+
"address-state": "TOO LONG",
54+
"address-zip_code": "",
55+
}
56+
form = AuthorForm(data, instance=author)
57+
self.assertFalse(form.is_valid())
58+
self.assertEqual(
59+
form.errors["address"],
60+
[
61+
"Ensure this value has at most 2 characters (it has 8).",
62+
"Enter all required values.",
63+
],
64+
)
65+
66+
def test_all_missing_data(self):
67+
author = Author.objects.create(
68+
name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001")
69+
)
70+
data = {
71+
"name": "Bob",
72+
"age": 51,
73+
"address-po_box": "",
74+
"address-city": "",
75+
"address-state": "",
76+
"address-zip_code": "",
77+
}
78+
form = AuthorForm(data, instance=author)
79+
self.assertFalse(form.is_valid())
80+
self.assertEqual(form.errors["address"], ["This field is required."])
81+
82+
def test_nullable_field(self):
83+
"""A nullable EmbeddedModelField is removed if all fields are empty."""
84+
author = Author.objects.create(
85+
name="Bob",
86+
age=50,
87+
address=Address(city="NYC", state="NY", zip_code="10001"),
88+
billing_address=Address(city="NYC", state="NY", zip_code="10001"),
89+
)
90+
data = {
91+
"name": "Bob",
92+
"age": 51,
93+
"address-po_box": "",
94+
"address-city": "New York City",
95+
"address-state": "NY",
96+
"address-zip_code": "10001",
97+
"billing_address-po_box": "",
98+
"billing_address-city": "",
99+
"billing_address-state": "",
100+
"billing_address-zip_code": "",
101+
}
102+
form = AuthorForm(data, instance=author)
103+
self.assertTrue(form.is_valid())
104+
form.save()
105+
author.refresh_from_db()
106+
self.assertIsNone(author.billing_address)
107+
108+
def test_rendering(self):
109+
form = AuthorForm()
110+
self.assertHTMLEqual(
111+
str(form.fields["address"].get_bound_field(form, "address")),
112+
"""
113+
<div>
114+
<label for="id_address-po_box">PO Box:</label>
115+
<input id="id_address-po_box" maxlength="50" name="address-po_box" type="text">
116+
</div>
117+
<div>
118+
<label for="id_address-city">City:</label>
119+
<input type="text" name="address-city" maxlength="20" required id="id_address-city">
120+
</div>
121+
<div>
122+
<label for="id_address-state">State:</label>
123+
<input type="text" name="address-state" maxlength="2" required
124+
id="id_address-state">
125+
</div>
126+
<div>
127+
<label for="id_address-zip_code">Zip code:</label>
128+
<input type="number" name="address-zip_code" required id="id_address-zip_code">
129+
</div>""",
130+
)

0 commit comments

Comments
 (0)