Skip to content

Commit a50db2b

Browse files
committed
fix #102
1 parent 477ce54 commit a50db2b

File tree

5 files changed

+146
-17
lines changed

5 files changed

+146
-17
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ rest = ["djangorestframework>=3.9,<4.0"]
7777

7878
[dependency-groups]
7979
dev = [
80+
"doc8>=1.1.2",
8081
"beautifulsoup4>=4.13.3",
8182
"coverage>=7.6.12",
8283
"darglint>=1.8.1",
@@ -104,7 +105,6 @@ dev = [
104105
"typing-extensions>=4.12.2",
105106
]
106107
docs = [
107-
"doc8>=1.1.2",
108108
"docutils>=0.21.2",
109109
"furo>=2024.8.6",
110110
"readme-renderer[md]>=44.0",

src/django_enum/fields.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -673,17 +673,20 @@ def formfield(self, form_class=None, choices_form_class=None, **kwargs):
673673
)
674674

675675
is_multi = self.enum and issubclass(self.enum, Flag)
676-
if is_multi and self.enum:
676+
if is_multi:
677677
kwargs["empty_value"] = self.enum(0)
678678
# why fail? - does this fail for single select too?
679679
# kwargs['show_hidden_initial'] = True
680680

681681
if not self.strict:
682682
kwargs.setdefault(
683-
"widget", NonStrictSelectMultiple if is_multi else NonStrictSelect
683+
"widget",
684+
NonStrictSelectMultiple(enum=self.enum)
685+
if is_multi
686+
else NonStrictSelect,
684687
)
685688
elif is_multi:
686-
kwargs.setdefault("widget", FlagSelectMultiple)
689+
kwargs.setdefault("widget", FlagSelectMultiple(enum=self.enum))
687690

688691
form_field = super().formfield(
689692
form_class=form_class,
@@ -1217,15 +1220,6 @@ def contribute_to_class(
12171220
# for non flag fields
12181221
IntegerField.contribute_to_class(self, cls, name, private_only=private_only)
12191222

1220-
def _coerce_to_value_type(self, value: Any) -> Any:
1221-
if (
1222-
isinstance(value, list)
1223-
or isinstance(value, tuple)
1224-
or isinstance(value, set)
1225-
):
1226-
value = reduce(or_, value)
1227-
return super()._coerce_to_value_type(value)
1228-
12291223

12301224
class SmallIntegerFlagField(FlagField, EnumPositiveSmallIntegerField):
12311225
"""

src/django_enum/forms.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from copy import copy
44
from decimal import DecimalException
55
from enum import Enum, Flag
6+
from functools import reduce
7+
from operator import or_
68
from typing import Any, Iterable, List, Optional, Protocol, Sequence, Tuple, Type, Union
79

810
from django.core.exceptions import ValidationError
@@ -85,8 +87,32 @@ class FlagSelectMultiple(SelectMultiple):
8587
A SelectMultiple widget for EnumFlagFields.
8688
"""
8789

90+
enum: Optional[Type[Flag]]
8891

89-
class NonStrictSelectMultiple(NonStrictMixin, SelectMultiple):
92+
def __init__(self, enum: Optional[Type[Flag]] = None, **kwargs):
93+
self.enum = enum
94+
super().__init__(**kwargs)
95+
96+
def format_value(self, value):
97+
"""
98+
Return a list of the flag's values.
99+
"""
100+
if not isinstance(value, list):
101+
# see impl of ChoiceWidget.optgroups
102+
# it compares the string conversion of the value of each
103+
# choice tuple to the string conversion of the value
104+
# to determine selected options
105+
if self.enum:
106+
return [str(en.value) for en in self.enum(value)]
107+
if isinstance(value, int):
108+
# automagically work for IntFlags even if we weren't given the enum
109+
return [
110+
str(1 << i) for i in range(value.bit_length()) if (value >> i) & 1
111+
]
112+
return value
113+
114+
115+
class NonStrictSelectMultiple(NonStrictMixin, FlagSelectMultiple):
90116
"""
91117
A SelectMultiple widget for non-strict EnumFlagFields that includes any
92118
existing non-conforming value as a choice option.
@@ -314,6 +340,8 @@ class EnumFlagField(ChoiceFieldMixin, TypedMultipleChoiceField): # type: ignore
314340
if strict=False, values can be outside of the enumerations
315341
"""
316342

343+
widget = FlagSelectMultiple
344+
317345
def __init__(
318346
self,
319347
enum: Optional[Type[Flag]] = None,
@@ -324,6 +352,10 @@ def __init__(
324352
choices: _ChoicesParameter = (),
325353
**kwargs,
326354
):
355+
kwargs.setdefault(
356+
"widget",
357+
self.widget(enum=enum) if strict else NonStrictSelectMultiple(enum=enum),
358+
)
327359
super().__init__(
328360
enum=enum,
329361
empty_value=(
@@ -334,3 +366,10 @@ def __init__(
334366
choices=choices,
335367
**kwargs,
336368
)
369+
370+
def _coerce(self, value: Any) -> Any:
371+
"""Combine the values into a single flag using |"""
372+
values = TypedMultipleChoiceField._coerce(self, value) # type: ignore[attr-defined]
373+
if values:
374+
return reduce(or_, values)
375+
return self.empty_value

tests/test_forms_ep.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
pytest.importorskip("enum_properties")
44
from tests.test_forms import FormTests, TestFormField
5-
from tests.enum_prop.models import EnumTester
5+
from tests.enum_prop.models import EnumTester, BitFieldModel
66
from tests.enum_prop.forms import EnumTesterForm
7+
from tests.examples.models import FlagExample
8+
from django_enum.forms import EnumFlagField, FlagSelectMultiple
9+
from django.forms import ModelForm
710

811

912
class EnumPropertiesFormTests(FormTests):
@@ -34,6 +37,99 @@ def model_params(self):
3437
"no_coerce": "Value 1",
3538
}
3639

40+
def test_flag_choices_admin_form(self):
41+
from django.contrib import admin
42+
43+
admin_class = admin.site._registry.get(BitFieldModel)
44+
self.assertIsInstance(
45+
admin_class.get_form(None).base_fields.get("bit_field_small"), EnumFlagField
46+
)
47+
48+
def test_flag_choices_model_form(self):
49+
from tests.examples.models.flag import Permissions
50+
from tests.enum_prop.enums import GNSSConstellation
51+
52+
class FlagChoicesModelForm(ModelForm):
53+
class Meta(EnumTesterForm.Meta):
54+
model = BitFieldModel
55+
56+
form = FlagChoicesModelForm(
57+
data={"bit_field_small": [GNSSConstellation.GPS, GNSSConstellation.GLONASS]}
58+
)
59+
60+
form.full_clean()
61+
self.assertTrue(form.is_valid())
62+
self.assertEqual(
63+
form.cleaned_data["bit_field_small"],
64+
GNSSConstellation.GPS | GNSSConstellation.GLONASS,
65+
)
66+
self.assertIsInstance(form.base_fields["bit_field_small"], EnumFlagField)
67+
68+
def test_extern_flag_admin_form(self):
69+
from django.contrib import admin
70+
71+
admin_class = admin.site._registry.get(FlagExample)
72+
self.assertIsInstance(
73+
admin_class.get_form(None).base_fields.get("permissions"), EnumFlagField
74+
)
75+
76+
def test_extern_flag_model_form(self):
77+
from tests.examples.models.flag import Permissions
78+
79+
class FlagModelForm(ModelForm):
80+
class Meta(EnumTesterForm.Meta):
81+
model = FlagExample
82+
83+
form = FlagModelForm(
84+
data={"permissions": [Permissions.READ, Permissions.WRITE]}
85+
)
86+
87+
form.full_clean()
88+
self.assertTrue(form.is_valid())
89+
self.assertEqual(
90+
form.cleaned_data["permissions"], Permissions.READ | Permissions.WRITE
91+
)
92+
self.assertIsInstance(form.base_fields["permissions"], EnumFlagField)
93+
94+
def test_flag_select_multiple_format(self):
95+
from tests.examples.models.flag import Permissions
96+
97+
widget = FlagSelectMultiple() # no enum
98+
self.assertEqual(
99+
widget.format_value(Permissions.READ | Permissions.WRITE),
100+
[str(Permissions.READ.value), str(Permissions.WRITE.value)],
101+
)
102+
self.assertEqual(
103+
widget.format_value(Permissions.READ | Permissions.EXECUTE),
104+
[str(Permissions.READ.value), str(Permissions.EXECUTE.value)],
105+
)
106+
self.assertEqual(
107+
widget.format_value(Permissions.EXECUTE | Permissions.WRITE),
108+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)],
109+
)
110+
111+
widget = FlagSelectMultiple(enum=Permissions) # no enum
112+
self.assertEqual(
113+
widget.format_value(Permissions.READ | Permissions.WRITE),
114+
[str(Permissions.READ.value), str(Permissions.WRITE.value)],
115+
)
116+
self.assertEqual(
117+
widget.format_value(Permissions.READ | Permissions.EXECUTE),
118+
[str(Permissions.READ.value), str(Permissions.EXECUTE.value)],
119+
)
120+
self.assertEqual(
121+
widget.format_value(Permissions.EXECUTE | Permissions.WRITE),
122+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)],
123+
)
124+
125+
# check pass through
126+
self.assertEqual(
127+
widget.format_value(
128+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)]
129+
),
130+
[str(Permissions.WRITE.value), str(Permissions.EXECUTE.value)],
131+
)
132+
37133

38134
FormTests = None
39135
TestFormField = None

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)