Skip to content

Commit d6f4892

Browse files
committed
ADD: Python support for decimal.Decimal
1 parent c84a06e commit d6f4892

File tree

3 files changed

+142
-64
lines changed

3 files changed

+142
-64
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Changelog
22

3+
## 0.22.0 - TBD
4+
5+
#### Enhancements
6+
- Added `price_type` argument for `DBNStore.to_df` to specify if price fields should be `fixed`, `float` or `decimal.Decimal`
7+
8+
#### Deprecations
9+
- Deprecated `pretty_px` argument for `DBNStore.to_df` to be removed in a future release; it now defaults to `None`, passing `pretty_px=True` is equivalent to `price_type="float"` (the default), and passing `pretty_px=False` is equivalent to `price_type="fixed"`
10+
311
## 0.21.0 - 2023-10-11
412

513
#### Enhancements

databento/common/dbnstore.py

Lines changed: 112 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from __future__ import annotations
22

33
import abc
4+
import decimal
45
import itertools
56
import logging
7+
import warnings
68
from collections.abc import Generator
79
from collections.abc import Iterator
10+
from functools import partial
811
from io import BytesIO
912
from os import PathLike
1013
from pathlib import Path
11-
from typing import IO, TYPE_CHECKING, Any, Callable, overload
14+
from typing import IO, TYPE_CHECKING, Any, Callable, Literal, overload
1215

1316
import databento_dbn
1417
import numpy as np
@@ -44,7 +47,6 @@
4447

4548
INT64_NULL = 9223372036854775807
4649

47-
4850
logger = logging.getLogger(__name__)
4951

5052
if TYPE_CHECKING:
@@ -92,37 +94,6 @@ def is_dbn(reader: IO[bytes]) -> bool:
9294
return reader.read(3) == b"DBN"
9395

9496

95-
def format_dataframe(
96-
df: pd.DataFrame,
97-
schema: Schema,
98-
pretty_px: bool,
99-
pretty_ts: bool,
100-
) -> pd.DataFrame:
101-
struct = SCHEMA_STRUCT_MAP[schema]
102-
103-
if schema == Schema.DEFINITION:
104-
for column, type_max in DEFINITION_TYPE_MAX_MAP.items():
105-
if column in df.columns:
106-
df[column] = df[column].where(df[column] != type_max, np.nan)
107-
108-
if pretty_px:
109-
for px_field in struct._price_fields:
110-
df[px_field] = df[px_field].replace(INT64_NULL, np.nan) / FIXED_PRICE_SCALE
111-
112-
if pretty_ts:
113-
for ts_field in struct._timestamp_fields:
114-
df[ts_field] = pd.to_datetime(df[ts_field], errors="coerce", utc=True)
115-
116-
for column, dtype in SCHEMA_DTYPES_MAP[schema]:
117-
if dtype.startswith("S") and column not in struct._hidden_fields:
118-
df[column] = df[column].str.decode("utf-8")
119-
120-
index_column = "ts_event" if schema.value.startswith("ohlcv") else "ts_recv"
121-
df.set_index(index_column, inplace=True)
122-
123-
return df
124-
125-
12697
class DataSource(abc.ABC):
12798
"""
12899
Abstract base class for backing DBNStore instances with data.
@@ -791,7 +762,7 @@ def to_csv(
791762
path: Path | str,
792763
pretty_px: bool = True,
793764
pretty_ts: bool = True,
794-
map_symbols: bool | None = None,
765+
map_symbols: bool = True,
795766
schema: Schema | str | None = None,
796767
) -> None:
797768
"""
@@ -826,8 +797,12 @@ def to_csv(
826797
Requires all the data to be brought up into memory to then be written.
827798
828799
"""
800+
price_type: Literal["fixed", "float"] = "fixed"
801+
if pretty_px is True:
802+
price_type = "float"
803+
829804
df_iter = self.to_df(
830-
pretty_px=pretty_px,
805+
price_type=price_type,
831806
pretty_ts=pretty_ts,
832807
map_symbols=map_symbols,
833808
schema=schema,
@@ -844,7 +819,8 @@ def to_csv(
844819
@overload
845820
def to_df(
846821
self,
847-
pretty_px: bool = ...,
822+
pretty_px: bool | None = ...,
823+
price_type: Literal["fixed", "float", "decimal"] = ...,
848824
pretty_ts: bool = ...,
849825
map_symbols: bool = ...,
850826
schema: Schema | str | None = ...,
@@ -855,7 +831,8 @@ def to_df(
855831
@overload
856832
def to_df(
857833
self,
858-
pretty_px: bool = ...,
834+
pretty_px: bool | None = ...,
835+
price_type: Literal["fixed", "float", "decimal"] = ...,
859836
pretty_ts: bool = ...,
860837
map_symbols: bool = ...,
861838
schema: Schema | str | None = ...,
@@ -865,7 +842,8 @@ def to_df(
865842

866843
def to_df(
867844
self,
868-
pretty_px: bool = True,
845+
pretty_px: bool | None = None,
846+
price_type: Literal["fixed", "float", "decimal"] = "float",
869847
pretty_ts: bool = True,
870848
map_symbols: bool = True,
871849
schema: Schema | str | None = None,
@@ -877,9 +855,15 @@ def to_df(
877855
Parameters
878856
----------
879857
pretty_px : bool, optional
858+
This parameter is deprecated and will be removed in a future release.
880859
If all price columns should be converted from `int` to `float` at
881860
the correct scale (using the fixed-precision scalar 1e-9). Null
882861
prices are replaced with NaN.
862+
price_type : str, default "float"
863+
The price type to use for price fields.
864+
If "fixed", prices will have a type of `int` in fixed decimal format; each unit representing 1e-9 or 0.000000001.
865+
If "float", prices will have a type of `float`.
866+
If "decimal", prices will be instances of `decimal.Decimal`.
883867
pretty_ts : bool, default True
884868
If all timestamp columns should be converted from UNIX nanosecond
885869
`int` to tz-aware UTC `pd.Timestamp`.
@@ -908,6 +892,20 @@ def to_df(
908892
If the schema for the array cannot be determined.
909893
910894
"""
895+
if pretty_px is True:
896+
warnings.warn(
897+
'The argument `pretty_px` is deprecated and will be removed in a future release; `price_type="float"` can be used instead.',
898+
DeprecationWarning,
899+
stacklevel=2,
900+
)
901+
elif pretty_px is False:
902+
price_type = "fixed"
903+
warnings.warn(
904+
'The argument `pretty_px` is deprecated and will be removed in a future release; `price_type="fixed"` can be used instead.',
905+
DeprecationWarning,
906+
stacklevel=2,
907+
)
908+
911909
schema = validate_maybe_enum(schema, Schema, "schema")
912910
if schema is None:
913911
if self.schema is None:
@@ -927,7 +925,7 @@ def to_df(
927925
schema=schema,
928926
count=count,
929927
instrument_map=self._instrument_map,
930-
pretty_px=pretty_px,
928+
price_type=price_type,
931929
pretty_ts=pretty_ts,
932930
map_symbols=map_symbols,
933931
)
@@ -966,7 +964,7 @@ def to_json(
966964
path: Path | str,
967965
pretty_px: bool = True,
968966
pretty_ts: bool = True,
969-
map_symbols: bool | None = None,
967+
map_symbols: bool = True,
970968
schema: Schema | str | None = None,
971969
) -> None:
972970
"""
@@ -1000,8 +998,12 @@ def to_json(
1000998
Requires all the data to be brought up into memory to then be written.
1001999
10021000
"""
1001+
price_type: Literal["fixed", "float"] = "fixed"
1002+
if pretty_px is True:
1003+
price_type = "float"
1004+
10031005
df_iter = self.to_df(
1004-
pretty_px=pretty_px,
1006+
price_type=price_type,
10051007
pretty_ts=pretty_ts,
10061008
map_symbols=map_symbols,
10071009
schema=schema,
@@ -1126,37 +1128,91 @@ def __init__(
11261128
count: int | None,
11271129
schema: Schema,
11281130
instrument_map: InstrumentMap,
1129-
pretty_px: bool = True,
1131+
price_type: Literal["fixed", "float", "decimal"] = "float",
11301132
pretty_ts: bool = True,
11311133
map_symbols: bool = True,
11321134
):
11331135
self._records = records
11341136
self._schema = schema
11351137
self._count = count
1136-
self._pretty_px = pretty_px
1138+
self._price_type = price_type
11371139
self._pretty_ts = pretty_ts
11381140
self._map_symbols = map_symbols
11391141
self._instrument_map = instrument_map
1142+
self._struct = SCHEMA_STRUCT_MAP[schema]
11401143

11411144
def __iter__(self) -> DataFrameIterator:
11421145
return self
11431146

11441147
def __next__(self) -> pd.DataFrame:
1145-
df = format_dataframe(
1146-
pd.DataFrame(
1147-
next(self._records),
1148-
columns=SCHEMA_COLUMNS[self._schema],
1149-
),
1150-
schema=self._schema,
1151-
pretty_px=self._pretty_px,
1152-
pretty_ts=self._pretty_ts,
1148+
df = pd.DataFrame(
1149+
next(self._records),
1150+
columns=SCHEMA_COLUMNS[self._schema],
11531151
)
11541152

1153+
if self._schema == Schema.DEFINITION:
1154+
self._format_definition_fields(df)
1155+
1156+
self._format_hidden_fields(df)
1157+
1158+
self._format_px(df, self._price_type)
1159+
1160+
if self._pretty_ts:
1161+
self._format_pretty_ts(df)
1162+
1163+
self._format_set_index(df)
1164+
11551165
if self._map_symbols:
1156-
df_index = df.index if self._pretty_ts else pd.to_datetime(df.index, utc=True)
1157-
dates = [ts.date() for ts in df_index]
1158-
df["symbol"] = [
1159-
self._instrument_map.resolve(inst, dates[i]) for i, inst in enumerate(df["instrument_id"])
1160-
]
1166+
self._format_map_symbols(df)
11611167

11621168
return df
1169+
1170+
def _format_definition_fields(self, df: pd.DataFrame) -> None:
    """
    Replace sentinel values in definition columns with NaN, in place.

    For every column named in `DEFINITION_TYPE_MAX_MAP` that is present in
    `df`, entries equal to the mapped value are replaced with NaN.

    Parameters
    ----------
    df : pd.DataFrame
        The frame to modify.

    """
    for name, sentinel in DEFINITION_TYPE_MAX_MAP.items():
        if name not in df.columns:
            continue
        values = df[name]
        # `where` keeps entries satisfying the condition and fills the rest.
        df[name] = values.where(values != sentinel, np.nan)
1174+
1175+
def _format_hidden_fields(self, df: pd.DataFrame) -> None:
    """
    Decode byte-string columns of `df` to text, in place.

    A column is decoded as UTF-8 when its dtype string in
    `SCHEMA_DTYPES_MAP` for this schema starts with "S" and the column is
    not listed in the struct's `_hidden_fields`.

    Parameters
    ----------
    df : pd.DataFrame
        The frame to modify.

    """
    skipped = self._struct._hidden_fields
    text_columns = [
        name
        for name, dtype in SCHEMA_DTYPES_MAP[self._schema]
        if dtype.startswith("S") and name not in skipped
    ]
    for name in text_columns:
        df[name] = df[name].str.decode("utf-8")
1180+
1181+
def _format_map_symbols(self, df: pd.DataFrame) -> None:
1182+
df_index = df.index if self._pretty_ts else pd.to_datetime(df.index, utc=True)
1183+
dates = [ts.date() for ts in df_index]
1184+
df["symbol"] = [
1185+
self._instrument_map.resolve(inst, dates[i])
1186+
for i, inst in enumerate(df["instrument_id"])
1187+
]
1188+
1189+
def _format_px(
1190+
self,
1191+
df: pd.DataFrame,
1192+
price_type: Literal["fixed", "float", "decimal"],
1193+
) -> None:
1194+
px_fields = self._struct._price_fields
1195+
1196+
if price_type == "decimal":
1197+
for field in px_fields:
1198+
df[field] = (
1199+
df[field].replace(INT64_NULL, np.nan).apply(decimal.Decimal)
1200+
/ FIXED_PRICE_SCALE
1201+
)
1202+
elif price_type == "float":
1203+
for field in px_fields:
1204+
df[field] = df[field].replace(INT64_NULL, np.nan) / FIXED_PRICE_SCALE
1205+
else:
1206+
return # do nothing
1207+
1208+
def _format_pretty_ts(self, df: pd.DataFrame) -> None:
1209+
for field in self._struct._timestamp_fields:
1210+
df[field] = df[field].apply(
1211+
partial(pd.to_datetime, utc=True, errors="coerce"),
1212+
)
1213+
1214+
def _format_set_index(self, df: pd.DataFrame) -> None:
1215+
index_column = (
1216+
"ts_event" if self._schema.value.startswith("ohlcv") else "ts_recv"
1217+
)
1218+
df.set_index(index_column, inplace=True)

0 commit comments

Comments
 (0)