|
1 | | -from enum import Enum, Flag, IntFlag, unique |
2 | | -from typing import Callable, Type, TypeVar, Union |
| 1 | +import warnings |
| 2 | +from enum import Enum, EnumMeta, Flag, IntFlag, unique |
| 3 | +from typing import Any, Callable, Iterable, Type, TypeVar, Union |
3 | 4 |
|
4 | 5 |
|
5 | 6 | M = TypeVar("M", bound=Enum) |
@@ -77,6 +78,46 @@ def coerced_new(enum: Type[M], value: object) -> M: |
77 | 78 | return enum_type |
78 | 79 |
|
79 | 80 |
|
| 81 | +class DeprecatedAccess(EnumMeta): |
| 82 | + """ |
| 83 | + runs a user-specified function whenever member is accessed |
| 84 | + """ |
| 85 | + |
| 86 | + def __getattribute__(cls, name: str) -> object: |
| 87 | + obj = super().__getattribute__(name) |
| 88 | + if isinstance(obj, Enum) and obj._on_access: # type: ignore |
| 89 | + obj._on_access() # type: ignore |
| 90 | + return obj |
| 91 | + |
| 92 | + def __getitem__(cls, name: str, *_: Iterable[Any]) -> object: # type: ignore |
| 93 | + member: Any = super().__getitem__(name) |
| 94 | + if member._on_access: |
| 95 | + member._on_access() |
| 96 | + return member |
| 97 | + |
| 98 | + def __call__( # type: ignore |
| 99 | + cls, |
| 100 | + value: str, |
| 101 | + names=None, |
| 102 | + *, |
| 103 | + module=None, |
| 104 | + qualname=None, |
| 105 | + type=None, |
| 106 | + start=1, |
| 107 | + ) -> object: |
| 108 | + obj = super().__call__( |
| 109 | + value, |
| 110 | + names, |
| 111 | + module=module, |
| 112 | + qualname=qualname, |
| 113 | + type=type, |
| 114 | + start=start, |
| 115 | + ) |
| 116 | + if isinstance(obj, Enum) and obj._on_access: |
| 117 | + obj._on_access() |
| 118 | + return obj |
| 119 | + |
| 120 | + |
80 | 121 | class StringyMixin: |
81 | 122 | """ |
82 | 123 | Mixin class for overloading __str__ on Enum types. |
@@ -206,19 +247,46 @@ class Delivery(StringyMixin, str, Enum): |
206 | 247 | DISK = "disk" |
207 | 248 |
|
208 | 249 |
|
| 250 | +def deprecated_enum(old_value: str, new_value: str) -> str: |
| 251 | + warnings.warn( |
| 252 | + f"{old_value} is deprecated to {new_value}", |
| 253 | + category=DeprecationWarning, |
| 254 | + stacklevel=3, # This makes the error happen in user code |
| 255 | + ) |
| 256 | + return new_value |
| 257 | + |
| 258 | + |
209 | 259 | @unique |
210 | 260 | @coercible |
211 | | -class SType(StringyMixin, str, Enum): |
| 261 | +class SType(StringyMixin, str, Enum, metaclass=DeprecatedAccess): |
212 | 262 | """Represents a symbology type.""" |
213 | 263 |
|
214 | | - PRODUCT_ID = "product_id" # Deprecated for `instrument_id` |
215 | | - NATIVE = "native" # Deprecated for `raw_symbol` |
216 | | - SMART = "smart" # Deprecated for `parent` and `continuous` |
| 264 | + PRODUCT_ID = "product_id", "instrument_id" # Deprecated for `instrument_id` |
| 265 | + NATIVE = "native", "raw_symbol" # Deprecated for `raw_symbol` |
| 266 | + SMART = "smart", "parent", "continuous" # Deprecated for `parent` and `continuous` |
217 | 267 | INSTRUMENT_ID = "instrument_id" |
218 | 268 | RAW_SYMBOL = "raw_symbol" |
219 | 269 | PARENT = "parent" |
220 | 270 | CONTINUOUS = "continuous" |
221 | 271 |
|
| 272 | + def __new__(cls, value: str, *args: Iterable[str]) -> "SType": |
| 273 | + variant = super().__new__(cls, value) |
| 274 | + variant._value_ = value |
| 275 | + variant.__args = args # type: ignore |
| 276 | + variant._on_access = variant.__deprecated if args else None # type: ignore |
| 277 | + return variant |
| 278 | + |
| 279 | + def __eq__(self, other: object) -> bool: |
| 280 | + return str(self) == str(other) |
| 281 | + |
| 282 | + def __deprecated(self) -> None: |
| 283 | + other_values = " or ".join(self.__args) # type: ignore |
| 284 | + warnings.warn( |
| 285 | + f"SType of {self.value} is deprecated; use {other_values}", |
| 286 | + category=DeprecationWarning, |
| 287 | + stacklevel=3, |
| 288 | + ) |
| 289 | + |
222 | 290 |
|
223 | 291 | @unique |
224 | 292 | @coercible |
|
0 commit comments