Skip to content

Commit 88ddb4f

Browse files
committed
moved _add_* methods to Enum; fixed unhashable value handling
1 parent b771360 commit 88ddb4f

File tree

3 files changed

+125
-66
lines changed

3 files changed

+125
-66
lines changed

Doc/howto/enum.rst

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -961,10 +961,6 @@ all the members are created it is no longer used.
961961
Supported ``_sunder_`` names
962962
""""""""""""""""""""""""""""
963963

964-
- :meth:`~EnumType._add_alias_` -- adds a new name as an alias to an existing
965-
member.
966-
- :meth:`~EnumType._add_value_alias_` -- adds a new value as an alias to an
967-
existing member.
968964
- :attr:`~Enum._name_` -- name of the member
969965
- :attr:`~Enum._value_` -- value of the member; can be set in ``__new__``
970966
- :meth:`~Enum._missing_` -- a lookup function used when a value is not found;
@@ -974,6 +970,10 @@ Supported ``_sunder_`` names
974970
from the final class
975971
- :meth:`~Enum._generate_next_value_` -- used to get an appropriate value for
976972
an enum member; may be overridden
973+
- :meth:`~Enum._add_alias_` -- adds a new name as an alias to an existing
974+
member.
975+
- :meth:`~Enum._add_value_alias_` -- adds a new value as an alias to an
976+
existing member. See `MultiValueEnum`_ for an example.
977977

978978
.. note::
979979

@@ -1451,6 +1451,29 @@ alias::
14511451
disallowing aliases, the :func:`unique` decorator can be used instead.
14521452

14531453

1454+
MultiValueEnum
1455+
^^^^^^^^^^^^^^^^^
1456+
1457+
Supports having more than one value per member::
1458+
1459+
>>> class MultiValueEnum(Enum):
1460+
... def __new__(cls, value, *values):
1461+
... self = object.__new__(cls)
1462+
... self._value_ = value
1463+
... for v in values:
1464+
... self._add_value_alias_(v)
1465+
... return self
1466+
...
1467+
>>> class DType(MultiValueEnum):
1468+
... float32 = 'f', 8
1469+
... double64 = 'd', 9
1470+
...
1471+
>>> DType('f')
1472+
<DType.float32: 'f'>
1473+
>>> DType(9)
1474+
<DType.double64: 'd'>
1475+
1476+
14541477
Planet
14551478
^^^^^^
14561479

Lib/enum.py

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def __set_name__(self, enum_class, member_name):
314314
# no other instances found, record this member in _member_names_
315315
enum_class._member_names_.append(member_name)
316316

317-
enum_class._add_member_(enum_member, member_name)
317+
enum_class._add_member_(member_name, enum_member)
318318
try:
319319
# This may fail if value is not hashable. We can't add the value
320320
# to the map, and by-value lookups for this value will be
@@ -323,6 +323,7 @@ def __set_name__(self, enum_class, member_name):
323323
except TypeError:
324324
# keep track of the value in a list so containment checks are quick
325325
enum_class._unhashable_values_.append(value)
326+
enum_class._unhashable_values_map_.setdefault(member_name, []).append(value)
326327

327328

328329
class EnumDict(dict):
@@ -356,6 +357,7 @@ def __setitem__(self, key, value):
356357
'_order_',
357358
'_generate_next_value_', '_numeric_repr_', '_missing_', '_ignore_',
358359
'_iter_member_', '_iter_member_by_value_', '_iter_member_by_def_',
360+
'_add_alias_', '_add_value_alias_',
359361
):
360362
raise ValueError(
361363
'_sunder_ names, such as %r, are reserved for future Enum use'
@@ -521,6 +523,7 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
521523
classdict['_member_map_'] = {}
522524
classdict['_value2member_map_'] = {}
523525
classdict['_unhashable_values_'] = []
526+
classdict['_unhashable_values_map_'] = {}
524527
classdict['_member_type_'] = member_type
525528
# now set the __repr__ for the value
526529
classdict['_value_repr_'] = metacls._find_data_repr_(cls, bases)
@@ -723,7 +726,10 @@ def __contains__(cls, value):
723726
"""
724727
if isinstance(value, cls):
725728
return True
726-
return value in cls._value2member_map_ or value in cls._unhashable_values_
729+
try:
730+
return value in cls._value2member_map_
731+
except TypeError:
732+
return value in cls._unhashable_values_
727733

728734
def __delattr__(cls, attr):
729735
# nicer error message when someone tries to delete an attribute
@@ -1020,7 +1026,8 @@ def _find_new_(mcls, classdict, member_type, first_enum):
10201026
use_args = True
10211027
return __new__, save_new, use_args
10221028

1023-
def _add_alias_(cls, member, name):
1029+
def _add_member_(cls, name, member):
1030+
# _value_ structures are not updated
10241031
if name in cls._member_map_:
10251032
if cls._member_map_[name] is not member:
10261033
raise NameError('%r is already bound: %r' % (name, cls._member_map_[name]))
@@ -1067,30 +1074,6 @@ def _add_alias_(cls, member, name):
10671074
cls._member_map_[name] = member
10681075
#
10691076
cls._member_map_[name] = member
1070-
_add_member_ = _add_alias_ # use _add_member_ internally
1071-
1072-
def _add_value_alias_(cls, member, value):
1073-
try:
1074-
if value in cls._value2member_map_:
1075-
if cls._value2member_map_[value] is not member:
1076-
raise ValueError('%r is already bound: %r' % (value, cls._value2member_map_[value]))
1077-
return
1078-
except TypeError:
1079-
# unhashable value, do long search
1080-
for m in cls._member_map_.values():
1081-
if m._value_ == value:
1082-
if m is not member:
1083-
raise ValueError('%r is already bound: %r' % (value, cls._value2member_map_[value]))
1084-
return
1085-
try:
1086-
# This may fail if value is not hashable. We can't add the value
1087-
# to the map, and by-value lookups for this value will be
1088-
# linear.
1089-
cls._value2member_map_.setdefault(value, member)
1090-
except TypeError:
1091-
# keep track of the value in a list so containment checks are quick
1092-
cls._unhashable_values_.append(value)
1093-
10941077

10951078
EnumMeta = EnumType # keep EnumMeta name for backwards compatibility
10961079

@@ -1158,9 +1141,9 @@ def __new__(cls, value):
11581141
pass
11591142
except TypeError:
11601143
# not there, now do long search -- O(n) behavior
1161-
for member in cls._member_map_.values():
1162-
if member._value_ is value or member._value_ == value:
1163-
return member
1144+
for name, values in cls._unhashable_values_map_.items():
1145+
if value in values:
1146+
return cls[name]
11641147
# still not found -- verify that members exist, in-case somebody got here mistakenly
11651148
# (such as via super when trying to override __new__)
11661149
if not cls._member_map_:
@@ -1201,6 +1184,33 @@ def __new__(cls, value):
12011184
def __init__(self, *args, **kwds):
12021185
pass
12031186

1187+
def _add_alias_(self, name):
1188+
self.__class__._add_member_(name, self)
1189+
1190+
def _add_value_alias_(self, value):
1191+
cls = self.__class__
1192+
try:
1193+
if value in cls._value2member_map_:
1194+
if cls._value2member_map_[value] is not self:
1195+
raise ValueError('%r is already bound: %r' % (value, cls._value2member_map_[value]))
1196+
return
1197+
except TypeError:
1198+
# unhashable value, do long search
1199+
for m in cls._member_map_.values():
1200+
if m._value_ == value:
1201+
if m is not self:
1202+
raise ValueError('%r is already bound: %r' % (value, cls._value2member_map_[value]))
1203+
return
1204+
try:
1205+
# This may fail if value is not hashable. We can't add the value
1206+
# to the map, and by-value lookups for this value will be
1207+
# linear.
1208+
cls._value2member_map_.setdefault(value, self)
1209+
except TypeError:
1210+
# keep track of the value in a list so containment checks are quick
1211+
cls._unhashable_values_.append(value)
1212+
cls._unhashable_values_map_.setdefault(self.name, []).append(value)
1213+
12041214
@staticmethod
12051215
def _generate_next_value_(name, start, count, last_values):
12061216
"""
@@ -1713,7 +1723,8 @@ def convert_class(cls):
17131723
body['_member_names_'] = member_names = []
17141724
body['_member_map_'] = member_map = {}
17151725
body['_value2member_map_'] = value2member_map = {}
1716-
body['_unhashable_values_'] = []
1726+
body['_unhashable_values_'] = unhashable_values = []
1727+
body['_unhashable_values_map_'] = {}
17171728
body['_member_type_'] = member_type = etype._member_type_
17181729
body['_value_repr_'] = etype._value_repr_
17191730
if issubclass(etype, Flag):
@@ -1760,14 +1771,9 @@ def convert_class(cls):
17601771
for name, value in attrs.items():
17611772
if isinstance(value, auto) and auto.value is _auto_null:
17621773
value = gnv(name, 1, len(member_names), gnv_last_values)
1763-
if value in value2member_map:
1774+
if value in value2member_map or value in unhashable_values:
17641775
# an alias to an existing member
1765-
member = value2member_map[value]
1766-
redirect = property()
1767-
redirect.member = member
1768-
redirect.__set_name__(enum_class, name)
1769-
setattr(enum_class, name, redirect)
1770-
member_map[name] = member
1776+
enum_class(value)._add_alias_(name)
17711777
else:
17721778
# create the member
17731779
if use_args:
@@ -1782,12 +1788,12 @@ def convert_class(cls):
17821788
member._name_ = name
17831789
member.__objclass__ = enum_class
17841790
member.__init__(value)
1785-
redirect = property()
1786-
redirect.member = member
1787-
redirect.__set_name__(enum_class, name)
1788-
setattr(enum_class, name, redirect)
1789-
member_map[name] = member
17901791
member._sort_order_ = len(member_names)
1792+
if name not in ('name', 'value'):
1793+
setattr(enum_class, name, member)
1794+
member_map[name] = member
1795+
else:
1796+
enum_class._add_member_(name, member)
17911797
value2member_map[value] = member
17921798
if _is_single_bit(value):
17931799
# not a multi-bit alias, record in _member_names_ and _flag_mask_
@@ -1810,14 +1816,13 @@ def convert_class(cls):
18101816
if value.value is _auto_null:
18111817
value.value = gnv(name, 1, len(member_names), gnv_last_values)
18121818
value = value.value
1813-
if value in value2member_map:
1819+
try:
1820+
contained = value in value2member_map
1821+
except TypeError:
1822+
contained = value in unhashable_values
1823+
if contained:
18141824
# an alias to an existing member
1815-
member = value2member_map[value]
1816-
redirect = property()
1817-
redirect.member = member
1818-
redirect.__set_name__(enum_class, name)
1819-
setattr(enum_class, name, redirect)
1820-
member_map[name] = member
1825+
enum_class(value)._add_alias_(name)
18211826
else:
18221827
# create the member
18231828
if use_args:
@@ -1833,14 +1838,22 @@ def convert_class(cls):
18331838
member.__objclass__ = enum_class
18341839
member.__init__(value)
18351840
member._sort_order_ = len(member_names)
1836-
redirect = property()
1837-
redirect.member = member
1838-
redirect.__set_name__(enum_class, name)
1839-
setattr(enum_class, name, redirect)
1840-
member_map[name] = member
1841-
value2member_map[value] = member
1841+
if name not in ('name', 'value'):
1842+
setattr(enum_class, name, member)
1843+
member_map[name] = member
1844+
else:
1845+
enum_class._add_member_(name, member)
18421846
member_names.append(name)
18431847
gnv_last_values.append(value)
1848+
try:
1849+
# This may fail if value is not hashable. We can't add the value
1850+
# to the map, and by-value lookups for this value will be
1851+
# linear.
1852+
enum_class._value2member_map_.setdefault(value, member)
1853+
except TypeError:
1854+
# keep track of the value in a list so containment checks are quick
1855+
enum_class._unhashable_values_.append(value)
1856+
enum_class._unhashable_values_map_.setdefault(name, []).append(value)
18441857
if '__new__' in body:
18451858
enum_class.__new_member__ = enum_class.__new__
18461859
enum_class.__new__ = Enum.__new__

Lib/test/test_enum.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ def test_contains_tf(self):
514514
self.assertFalse('first' in MainEnum)
515515
val = MainEnum.dupe
516516
self.assertIn(val, MainEnum)
517+
self.assertNotIn(float('nan'), MainEnum)
517518
#
518519
class OtherEnum(Enum):
519520
one = auto()
@@ -3291,10 +3292,10 @@ class Color(mixin, Enum):
32913292
RED = 1
32923293
GREEN = 2
32933294
BLUE = 3
3294-
Color._add_alias_(Color.RED, 'ROJO')
3295+
Color.RED._add_alias_('ROJO')
32953296
self.assertIs(Color.RED, Color['ROJO'])
32963297
self.assertIs(Color.RED, Color.ROJO)
3297-
Color._add_alias_(Color.BLUE, 'ORG')
3298+
Color.BLUE._add_alias_('ORG')
32983299
self.assertIs(Color.BLUE, Color['ORG'])
32993300
self.assertIs(Color.BLUE, Color.ORG)
33003301
self.assertEqual(Color.RED.ORG, 'huh')
@@ -3307,7 +3308,7 @@ class Color(Enum):
33073308
RED = 1
33083309
GREEN = 2
33093310
BLUE = 3
3310-
Color._add_value_alias_(Color.RED, 5)
3311+
Color.RED._add_value_alias_(5)
33113312
self.assertIs(Color.RED, Color(5))
33123313

33133314
def test_add_value_alias_during_creation(self):
@@ -3319,7 +3320,7 @@ def __new__(cls, int_value, *value_aliases):
33193320
member = object.__new__(cls)
33203321
member._value_ = int_value
33213322
for alias in value_aliases:
3322-
cls._add_value_alias_(member, alias)
3323+
member._add_value_alias_(alias)
33233324
return member
33243325
self.assertIs(Types(0), Types.Unknown)
33253326
self.assertIs(Types(1), Types.Source)
@@ -5000,12 +5001,14 @@ class CheckedColor(Enum):
50005001
@bltns.property
50015002
def zeroth(self):
50025003
return 'zeroed %s' % self.name
5003-
self.assertTrue(_test_simple_enum(CheckedColor, SimpleColor) is None)
5004+
_test_simple_enum(CheckedColor, SimpleColor)
50045005
SimpleColor.MAGENTA._value_ = 9
50055006
self.assertRaisesRegex(
50065007
TypeError, "enum mismatch",
50075008
_test_simple_enum, CheckedColor, SimpleColor,
50085009
)
5010+
#
5011+
#
50095012
class CheckedMissing(IntFlag, boundary=KEEP):
50105013
SIXTY_FOUR = 64
50115014
ONE_TWENTY_EIGHT = 128
@@ -5022,8 +5025,28 @@ class Missing:
50225025
ALL = 2048 + 128 + 64 + 12
50235026
M = Missing
50245027
self.assertEqual(list(CheckedMissing), [M.SIXTY_FOUR, M.ONE_TWENTY_EIGHT, M.TWENTY_FORTY_EIGHT])
5025-
#
50265028
_test_simple_enum(CheckedMissing, Missing)
5029+
#
5030+
#
5031+
class CheckedUnhashable(Enum):
5032+
ONE = dict()
5033+
TWO = set()
5034+
name = 'python'
5035+
self.assertIn(dict(), CheckedUnhashable)
5036+
self.assertIn('python', CheckedUnhashable)
5037+
self.assertEqual(CheckedUnhashable.name.value, 'python')
5038+
self.assertEqual(CheckedUnhashable.name.name, 'name')
5039+
#
5040+
@_simple_enum()
5041+
class Unhashable:
5042+
ONE = dict()
5043+
TWO = set()
5044+
name = 'python'
5045+
self.assertIn(dict(), Unhashable)
5046+
self.assertIn('python', Unhashable)
5047+
self.assertEqual(Unhashable.name.value, 'python')
5048+
self.assertEqual(Unhashable.name.name, 'name')
5049+
_test_simple_enum(Unhashable, Unhashable)
50275050

50285051

50295052
class MiscTestCase(unittest.TestCase):

0 commit comments

Comments
 (0)