Skip to content

Commit 194c7c7

Browse files
ethanfurmanlexierule
authored andcommitted
Improve LigtningEnum, etc. (#12750)
1 parent fce8d4b commit 194c7c7

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

pytorch_lightning/utilities/enums.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ class LightningEnum(str, Enum):
2727

2828
@classmethod
2929
def from_str(cls, value: str) -> LightningEnum | None:
30-
statuses = [status for status in dir(cls) if not status.startswith("_")]
30+
statuses = cls.__members__.keys()
3131
for st in statuses:
3232
if st.lower() == value.lower():
33-
return getattr(cls, st)
33+
return cls[st]
3434
return None
3535

3636
def __eq__(self, other: object) -> bool:
@@ -43,21 +43,21 @@ def __hash__(self) -> int:
4343
return hash(self.value.lower())
4444

4545

46-
class _OnAccessEnumMeta(EnumMeta):
47-
"""Enum with a hook to run a function whenever a member is accessed.
46+
class _DeprecatedEnumMeta(EnumMeta):
47+
"""Enum that calls `deprecate()` whenever a member is accessed.
4848
49-
Adapted from:
50-
https://www.buzzphp.com/posts/how-do-i-detect-and-invoke-a-function-when-a-python-enum-member-is-accessed
49+
Adapted from: https://stackoverflow.com/a/62309159/208880
5150
"""
5251

5352
def __getattribute__(cls, name: str) -> Any:
5453
obj = super().__getattribute__(name)
55-
if isinstance(obj, Enum):
54+
# ignore __dunder__ names -- prevents potential recursion errors
55+
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
5656
obj.deprecate()
5757
return obj
5858

5959
def __getitem__(cls, name: str) -> Any:
60-
member: _OnAccessEnumMeta = super().__getitem__(name)
60+
member: _DeprecatedEnumMeta = super().__getitem__(name)
6161
member.deprecate()
6262
return member
6363

@@ -68,6 +68,12 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
6868
return obj
6969

7070

71+
class _DeprecatedEnum(LightningEnum, metaclass=_DeprecatedEnumMeta):
72+
"""_DeprecatedEnum calls an enum's `deprecate()` method on member access."""
73+
74+
pass
75+
76+
7177
class AMPType(LightningEnum):
7278
"""Type of Automatic Mixed Precission used for training.
7379
@@ -104,7 +110,7 @@ def supported_types() -> list[str]:
104110
return [x.value for x in PrecisionType]
105111

106112

107-
class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
113+
class DistributedType(_DeprecatedEnum):
108114
"""Define type of training strategy.
109115
110116
Deprecated since v1.6.0 and will be removed in v1.8.0.
@@ -146,7 +152,7 @@ def deprecate(self) -> None:
146152
)
147153

148154

149-
class DeviceType(LightningEnum, metaclass=_OnAccessEnumMeta):
155+
class DeviceType(_DeprecatedEnum):
150156
"""Define Device type by its nature - accelerators.
151157
152158
Deprecated since v1.6.0 and will be removed in v1.8.0.

0 commit comments

Comments
 (0)