Skip to content

Commit ddf49bf

Browse files
committed
MOD: Enum type coercion enhancement
1 parent f4e281b commit ddf49bf

19 files changed

+566
-722
lines changed

databento/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
Delivery,
88
Encoding,
99
FeedMode,
10-
Flags,
1110
HistoricalGateway,
1211
LiveGateway,
1312
Packaging,
13+
RecordFlags,
1414
RollRule,
1515
Schema,
1616
SplitDuration,
@@ -41,7 +41,7 @@
4141
"Encoding",
4242
"FeedMode",
4343
"FileBento",
44-
"Flags",
44+
"RecordFlags",
4545
"Historical",
4646
"HistoricalGateway",
4747
"LiveGateway",

databento/common/bento.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ def symbology(self) -> Dict[str, Any]:
392392

393393
symbology: Dict[str, Any] = {
394394
"symbols": self.symbols,
395-
"stype_in": self.stype_in.value,
396-
"stype_out": self.stype_out.value,
395+
"stype_in": str(self.stype_in),
396+
"stype_out": str(self.stype_out),
397397
"start_date": str(self.start.date()),
398398
"end_date": str(self.end.date()),
399399
"partial": self._metadata["partial"],

databento/common/enums.py

Lines changed: 204 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,113 @@
1-
from enum import Enum, unique
1+
from enum import Enum, Flag, IntFlag, unique
2+
from typing import Callable, Type, TypeVar, Union
3+
4+
5+
M = TypeVar("M", bound=Enum)
6+
7+
8+
def coercible(enum_type: Type[M]) -> Type[M]:
9+
"""Decorate coercible enumerations.
10+
11+
Decorating an Enum class with this function will intercept calls to
12+
__new__ and perform a type coercion for the passed value. The type conversion
13+
function is chosen based on the subclass of the Enum type.
14+
15+
Currently supported subclasses types:
16+
int
17+
values are passed to int()
18+
str
19+
values are passed to str(), the result is also lowercased
20+
21+
Parameters
22+
----------
23+
enum_type : EnumMeta
24+
The deocrated Enum type.
25+
26+
Returns
27+
-------
28+
EnumMeta
29+
30+
Raises
31+
------
32+
ValueError
33+
When an invalid value of the Enum is given.
34+
35+
36+
Notes
37+
-----
38+
This decorator makes some assuptions about your Enum class.
39+
1. Your attribute names are all UPPERCASE
40+
2. Your attribute values are all lowercase
41+
42+
"""
43+
_new: Callable[[Type[M], object], M] = enum_type.__new__
44+
45+
def _cast_str(value: object) -> str:
46+
return str(value).lower()
47+
48+
coerce_fn: Callable[[object], Union[str, int]]
49+
if issubclass(enum_type, int):
50+
coerce_fn = int
51+
elif issubclass(enum_type, str):
52+
coerce_fn = _cast_str
53+
else:
54+
raise TypeError(f"{enum_type} does not a subclass a coercible type.")
55+
56+
def coerced_new(enum: Type[M], value: object) -> M:
57+
if value is None:
58+
raise TypeError(
59+
f"value `{value}` is not coercible to {enum_type.__name__}.",
60+
)
61+
try:
62+
return _new(enum, coerce_fn(value))
63+
except ValueError as ve:
64+
name_to_try = str(value).replace(".", "_").replace("-", "_").upper()
65+
named = enum._member_map_.get(name_to_try)
66+
if named is not None:
67+
return named
68+
enum_values = tuple(value for value in enum._value2member_map_)
69+
70+
raise ValueError(
71+
f"value `{value}` is not a member of {enum_type.__name__}. "
72+
f"use one of {enum_values}.",
73+
) from ve
74+
75+
setattr(enum_type, "__new__", coerced_new)
76+
77+
return enum_type
78+
79+
80+
class StringyMixin:
81+
"""
82+
Mixin class for overloading __str__ on Enum types.
83+
This will use the Enumerations subclass, if any, to modify
84+
the behavior of str().
85+
86+
For subclasses of enum.Flag a comma separated string of names is returned.
87+
For integer enumerations, the lowercase member name is returned.
88+
For string enumerations, the value is returned.
89+
90+
"""
91+
92+
def __str__(self) -> str:
93+
if isinstance(self, Flag):
94+
return ", ".join(f.name.lower() for f in self.__class__ if f in self)
95+
if isinstance(self, int):
96+
return getattr(self, "name").lower()
97+
return getattr(self, "value")
298

399

4100
@unique
5-
class HistoricalGateway(Enum):
101+
@coercible
102+
class HistoricalGateway(StringyMixin, str, Enum):
6103
"""Represents a historical data center gateway location."""
7104

8105
BO1 = "bo1"
9106

10107

11108
@unique
12-
class LiveGateway(Enum):
109+
@coercible
110+
class LiveGateway(StringyMixin, str, Enum):
13111
"""Represents a live data center gateway location."""
14112

15113
ORIGIN = "origin"
@@ -18,7 +116,8 @@ class LiveGateway(Enum):
18116

19117

20118
@unique
21-
class FeedMode(Enum):
119+
@coercible
120+
class FeedMode(StringyMixin, str, Enum):
22121
"""Represents a data feed mode."""
23122

24123
HISTORICAL = "historical"
@@ -27,15 +126,17 @@ class FeedMode(Enum):
27126

28127

29128
@unique
30-
class Dataset(Enum):
129+
@coercible
130+
class Dataset(StringyMixin, str, Enum):
31131
"""Represents a dataset code (string identifier)."""
32132

33133
GLBX_MDP3 = "GLBX.MDP3"
34134
XNAS_ITCH = "XNAS.ITCH"
35135

36136

37137
@unique
38-
class Schema(Enum):
138+
@coercible
139+
class Schema(StringyMixin, str, Enum):
39140
"""Represents a data record schema."""
40141

41142
MBO = "mbo"
@@ -53,9 +154,43 @@ class Schema(Enum):
53154
GATEWAY_ERROR = "gateway_error"
54155
SYMBOL_MAPPING = "symbol_mapping"
55156

157+
@classmethod
158+
def from_int(cls, value: int) -> "Schema":
159+
""" """
160+
if value == 0:
161+
return cls.MBO
162+
if value == 1:
163+
return cls.MBP_1
164+
if value == 2:
165+
return cls.MBP_10
166+
if value == 3:
167+
return cls.TBBO
168+
if value == 4:
169+
return cls.TRADES
170+
if value == 5:
171+
return cls.OHLCV_1S
172+
if value == 6:
173+
return cls.OHLCV_1M
174+
if value == 7:
175+
return cls.OHLCV_1H
176+
if value == 8:
177+
return cls.OHLCV_1D
178+
if value == 9:
179+
return cls.DEFINITION
180+
if value == 10:
181+
return cls.STATISTICS
182+
if value == 11:
183+
return cls.STATUS
184+
if value == 12:
185+
return cls.GATEWAY_ERROR
186+
if value == 13:
187+
return cls.SYMBOL_MAPPING
188+
raise ValueError(f"value `{value}` is not a valid member of {cls.__name__}")
189+
56190

57191
@unique
58-
class Encoding(Enum):
192+
@coercible
193+
class Encoding(StringyMixin, str, Enum):
59194
"""Represents a data output encoding."""
60195

61196
DBN = "dbn"
@@ -64,15 +199,26 @@ class Encoding(Enum):
64199

65200

66201
@unique
67-
class Compression(Enum):
202+
@coercible
203+
class Compression(StringyMixin, str, Enum):
68204
"""Represents a data compression format (if any)."""
69205

70206
NONE = "none"
71207
ZSTD = "zstd"
72208

209+
@classmethod
210+
def from_int(cls, value: int) -> "Compression":
211+
""" """
212+
if value == 0:
213+
return cls.NONE
214+
if value == 1:
215+
return cls.ZSTD
216+
raise ValueError(f"value `{value}` is not a valid member of {cls.__name__}")
217+
73218

74219
@unique
75-
class SplitDuration(Enum):
220+
@coercible
221+
class SplitDuration(StringyMixin, str, Enum):
76222
"""Represents the duration before splitting for each batched data file."""
77223

78224
DAY = "day"
@@ -82,7 +228,8 @@ class SplitDuration(Enum):
82228

83229

84230
@unique
85-
class Packaging(Enum):
231+
@coercible
232+
class Packaging(StringyMixin, str, Enum):
86233
"""Represents the packaging method for batched data files."""
87234

88235
NONE = "none"
@@ -91,7 +238,8 @@ class Packaging(Enum):
91238

92239

93240
@unique
94-
class Delivery(Enum):
241+
@coercible
242+
class Delivery(StringyMixin, str, Enum):
95243
"""Represents the delivery mechanism for batched data."""
96244

97245
DOWNLOAD = "download"
@@ -100,25 +248,39 @@ class Delivery(Enum):
100248

101249

102250
@unique
103-
class SType(Enum):
251+
@coercible
252+
class SType(StringyMixin, str, Enum):
104253
"""Represents a symbology type."""
105254

106255
PRODUCT_ID = "product_id"
107256
NATIVE = "native"
108257
SMART = "smart"
109258

259+
@classmethod
260+
def from_int(cls, value: int) -> "SType":
261+
""" """
262+
if value == 0:
263+
return cls.PRODUCT_ID
264+
if value == 1:
265+
return cls.NATIVE
266+
if value == 2:
267+
return cls.SMART
268+
raise ValueError(f"value `{value}` is not a valid member of {cls.__name__}")
269+
110270

111271
@unique
112-
class RollRule(Enum):
272+
@coercible
273+
class RollRule(StringyMixin, str, Enum):
113274
"""Represents a smart symbology roll rule."""
114275

115-
VOLUME = 0
116-
OPEN_INTEREST = 1
117-
CALENDAR = 2
276+
VOLUME = "volume"
277+
OPEN_INTEREST = "open_interst"
278+
CALENDAR = "calendar"
118279

119280

120281
@unique
121-
class SymbologyResolution(Enum):
282+
@coercible
283+
class SymbologyResolution(StringyMixin, str, Enum):
122284
"""
123285
Status code of symbology resolution.
124286
@@ -127,20 +289,31 @@ class SymbologyResolution(Enum):
127289
- NOT_FOUND: One or more symbols where not found on any date in range.
128290
"""
129291

130-
OK = 0
131-
PARTIAL = 1
132-
NOT_FOUND = 2
292+
OK = "ok"
293+
PARTIAL = "partial"
294+
NOT_FOUND = "not_found"
133295

134296

135297
@unique
136-
class Flags(Enum):
137-
"""Represents record flags."""
138-
139-
# Last message in the packet from the venue for a given `product_id`
140-
F_LAST = 1 << 7
141-
# Message sourced from a replay, such as a snapshot server
142-
F_SNAPSHOT = 1 << 5
143-
# Aggregated price level message, not an individual order
144-
F_MBP = 1 << 4
145-
# The `ts_recv` value is inaccurate (clock issues or reordering)
146-
F_BAD_TS_RECV = 1 << 3
298+
@coercible
299+
# Ignore type to work around mypy bug https://github.com/python/mypy/issues/9319
300+
class RecordFlags(StringyMixin, IntFlag): # type: ignore
301+
"""Represents record flags.
302+
303+
F_LAST
304+
Last message in the packet from the venue for a given `product_id`
305+
F_SNAPSHOT
306+
Message sourced from a replay, such as a snapshot server
307+
F_MBP
308+
Aggregated price level message, not an individual order
309+
F_BAD_TS_RECV
310+
The `ts_recv` value is inaccurate (clock issues or reordering)
311+
312+
Other bits are reserved and have no current meaning.
313+
314+
"""
315+
316+
F_LAST = 128
317+
F_SNAPSHOT = 32
318+
F_MBP = 16
319+
F_BAD_TS_RECV = 8

0 commit comments

Comments
 (0)