Skip to content

Commit 06720b8

Browse files
authored
Allow to narrow down choices in an Enum field (#31)
1 parent 801ef61 commit 06720b8

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

cleancat/base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import inspect
23
import re
34
import sys
45

@@ -574,12 +575,31 @@ def clean(self, value):
574575
class Enum(Choices):
575576
"""Like Choices, but expects a Python 3 Enum."""
576577

578+
def __init__(self, choices, **kwargs):
579+
"""Initialize the Enum field.
580+
581+
The `choices` param can be either:
582+
* an enum.Enum class (in which case all of its values will become
583+
valid choices),
584+
* a list containing a subset of the enum's choices (e.g.
585+
`[SomeEnumCls.OptionA, SomeEnumCls.OptionB]`). You must provide
586+
more than one choice in this list and *all* of the choices must
587+
belong to the same enum class.
588+
"""
589+
is_cls = inspect.isclass(choices)
590+
if is_cls:
591+
self.enum_cls = choices
592+
else:
593+
assert choices, 'You need to provide at least one enum choice.'
594+
self.enum_cls = choices[0].__class__
595+
return super(Enum, self).__init__(choices, **kwargs)
596+
577597
def get_choices(self):
578598
return [choice.value for choice in self.choices]
579599

580600
def clean(self, value):
581601
value = super(Enum, self).clean(value)
582-
return self.choices(value)
602+
return self.enum_cls(value)
583603

584604
def serialize(self, choice):
585605
if choice is not None:

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from setuptools import setup
23

34
install_requirements = [
@@ -11,6 +12,9 @@
1112
'sqlalchemy'
1213
]
1314

15+
if sys.version_info[:2] < (3, 4):
16+
test_requirements += ['enum34']
17+
1418
setup(
1519
name='cleancat',
1620
version='0.7.3',

tests/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import datetime
3+
import enum
34
import re
4-
import sys
55
import unittest
66

77
import pytest
@@ -319,10 +319,7 @@ class CaseInsensitiveChoiceSchema(Schema):
319319
self.assertInvalid(CaseInsensitiveChoiceSchema({'choice': 'world '}), {'field-errors': ['choice']})
320320
self.assertInvalid(CaseInsensitiveChoiceSchema({'choice': 'invalid'}), {'field-errors': ['choice']})
321321

322-
@unittest.skipIf(sys.version_info < (3, 4), 'enum unavailable')
323322
def test_enum(self):
324-
import enum
325-
326323
class MyChoices(enum.Enum):
327324
A = 'a'
328325
B = 'b'
@@ -334,6 +331,19 @@ class ChoiceSchema(Schema):
334331
self.assertValid(ChoiceSchema({'choice': 'b'}), {'choice': MyChoices.B})
335332
self.assertInvalid(ChoiceSchema({'choice': 'c'}), {'field-errors': ['choice']})
336333

334+
def test_enum_with_choices(self):
335+
class MyChoices(enum.Enum):
336+
A = 'a'
337+
B = 'b'
338+
C = 'c'
339+
340+
class ChoiceSchema(Schema):
341+
choice = Enum([MyChoices.A, MyChoices.B])
342+
343+
self.assertValid(ChoiceSchema({'choice': 'a'}), {'choice': MyChoices.A})
344+
self.assertValid(ChoiceSchema({'choice': 'b'}), {'choice': MyChoices.B})
345+
self.assertInvalid(ChoiceSchema({'choice': 'c'}), {'field-errors': ['choice']})
346+
337347
def test_url(self):
338348
class URLSchema(Schema):
339349
url = URL()
@@ -658,10 +668,7 @@ class TestSchema(Schema):
658668
'dictionary': {},
659669
}
660670

661-
@unittest.skipIf(sys.version_info < (3, 4), 'enum unavailable')
662671
def test_serialization_enum(self):
663-
import enum
664-
665672
class MyChoices(enum.Enum):
666673
A = 'a'
667674
B = 'b'

0 commit comments

Comments
 (0)