Skip to content

Commit bd0db0f

Browse files
committed
MOD: Add PriceType enum for validating price_type
1 parent 916e060 commit bd0db0f

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
This release drops support for Python 3.8 which has reached end-of-life.
66

77
#### Enhancements
8+
- Added `PriceType` enum for validation of `price_type` parameter in `DBNStore.to_df`
89
- Upgraded `databento-dbn` to 0.22.1
910

1011
#### Bug fixes

databento/common/dbnstore.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from databento.common.constants import DEFINITION_TYPE_MAX_MAP
4848
from databento.common.constants import SCHEMA_STRUCT_MAP
4949
from databento.common.constants import SCHEMA_STRUCT_MAP_V1
50+
from databento.common.enums import PriceType
5051
from databento.common.error import BentoError
5152
from databento.common.error import BentoWarning
5253
from databento.common.symbology import InstrumentMap
@@ -848,7 +849,7 @@ def to_csv(
848849
@overload
849850
def to_df(
850851
self,
851-
price_type: Literal["fixed", "float", "decimal"] = ...,
852+
price_type: PriceType | str = ...,
852853
pretty_ts: bool = ...,
853854
map_symbols: bool = ...,
854855
schema: Schema | str | None = ...,
@@ -859,7 +860,7 @@ def to_df(
859860
@overload
860861
def to_df(
861862
self,
862-
price_type: Literal["fixed", "float", "decimal"] = ...,
863+
price_type: PriceType | str = ...,
863864
pretty_ts: bool = ...,
864865
map_symbols: bool = ...,
865866
schema: Schema | str | None = ...,
@@ -869,7 +870,7 @@ def to_df(
869870

870871
def to_df(
871872
self,
872-
price_type: Literal["fixed", "float", "decimal"] = "float",
873+
price_type: PriceType | str = PriceType.FLOAT,
873874
pretty_ts: bool = True,
874875
map_symbols: bool = True,
875876
schema: Schema | str | None = None,
@@ -883,7 +884,7 @@ def to_df(
883884
884885
Parameters
885886
----------
886-
price_type : str, default "float"
887+
price_type : PriceType or str, default "float"
887888
The price type to use for price fields.
888889
If "fixed", prices will have a type of `int` in fixed decimal format; each unit representing 1e-9 or 0.000000001.
889890
If "float", prices will have a type of `float`.
@@ -918,6 +919,7 @@ def to_df(
918919
If the DBN schema is unspecified and cannot be determined.
919920
920921
"""
922+
price_type = validate_enum(price_type, PriceType, "price_type")
921923
schema = validate_maybe_enum(schema, Schema, "schema")
922924

923925
if isinstance(tz, Default):
@@ -1422,7 +1424,7 @@ def __init__(
14221424
struct_type: type[DBNRecord],
14231425
instrument_map: InstrumentMap,
14241426
tz: pytz.BaseTzInfo,
1425-
price_type: Literal["fixed", "float", "decimal"] = "float",
1427+
price_type: PriceType = PriceType.FLOAT,
14261428
pretty_ts: bool = True,
14271429
map_symbols: bool = True,
14281430
):
@@ -1499,16 +1501,16 @@ def _format_timezone(self, df: pd.DataFrame) -> None:
14991501
def _format_px(
15001502
self,
15011503
df: pd.DataFrame,
1502-
price_type: Literal["fixed", "float", "decimal"],
1504+
price_type: PriceType,
15031505
) -> None:
15041506
px_fields = self._struct_type._price_fields
15051507

1506-
if price_type == "decimal":
1508+
if price_type == PriceType.DECIMAL:
15071509
df[px_fields] = (
15081510
df[px_fields].replace(UNDEF_PRICE, np.nan).applymap(decimal.Decimal)
15091511
/ FIXED_PRICE_SCALE
15101512
)
1511-
elif price_type == "float":
1513+
elif price_type == PriceType.FLOAT:
15121514
df[px_fields] = df[px_fields].replace(UNDEF_PRICE, np.nan) / FIXED_PRICE_SCALE
15131515
else:
15141516
return # do nothing

databento/common/enums.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,5 +225,21 @@ class RecordFlags(StringyMixin, IntFlag): # type: ignore
225225
@unique
226226
@coercible
227227
class ReconnectPolicy(StringyMixin, str, Enum):
228+
"""
229+
Live session reconnection policy.
230+
"""
231+
228232
NONE = "none"
229233
RECONNECT = "reconnect"
234+
235+
236+
@unique
237+
@coercible
238+
class PriceType(StringyMixin, str, Enum):
239+
"""
240+
Price type for DataFrame price fields.
241+
"""
242+
243+
FIXED = "fixed"
244+
FLOAT = "float"
245+
DECIMAL = "decimal"

tests/test_historical_bento.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,18 @@ def test_to_df_with_price_type_handles_null(
646646
assert all(np.isnan(df_pretty["strike_price"]))
647647

648648

649+
def test_to_df_with_price_type_invalid(
650+
test_data: Callable[[Dataset, Schema], bytes],
651+
) -> None:
652+
# Arrange
653+
stub_data = test_data(Dataset.GLBX_MDP3, Schema.DEFINITION)
654+
data = DBNStore.from_bytes(data=stub_data)
655+
656+
# Act, Assert
657+
with pytest.raises(ValueError):
658+
data.to_df(price_type="US/Eastern")
659+
660+
649661
@pytest.mark.parametrize(
650662
"dataset",
651663
[

0 commit comments

Comments
 (0)