11from __future__ import annotations
22
33import abc
4+ import decimal
45import itertools
56import logging
7+ import warnings
68from collections .abc import Generator
79from collections .abc import Iterator
10+ from functools import partial
811from io import BytesIO
912from os import PathLike
1013from pathlib import Path
11- from typing import IO , TYPE_CHECKING , Any , Callable , overload
14+ from typing import IO , TYPE_CHECKING , Any , Callable , Literal , overload
1215
1316import databento_dbn
1417import numpy as np
4447
4548INT64_NULL = 9223372036854775807
4649
47-
4850logger = logging .getLogger (__name__ )
4951
5052if TYPE_CHECKING :
@@ -92,37 +94,6 @@ def is_dbn(reader: IO[bytes]) -> bool:
9294 return reader .read (3 ) == b"DBN"
9395
9496
95- def format_dataframe (
96- df : pd .DataFrame ,
97- schema : Schema ,
98- pretty_px : bool ,
99- pretty_ts : bool ,
100- ) -> pd .DataFrame :
101- struct = SCHEMA_STRUCT_MAP [schema ]
102-
103- if schema == Schema .DEFINITION :
104- for column , type_max in DEFINITION_TYPE_MAX_MAP .items ():
105- if column in df .columns :
106- df [column ] = df [column ].where (df [column ] != type_max , np .nan )
107-
108- if pretty_px :
109- for px_field in struct ._price_fields :
110- df [px_field ] = df [px_field ].replace (INT64_NULL , np .nan ) / FIXED_PRICE_SCALE
111-
112- if pretty_ts :
113- for ts_field in struct ._timestamp_fields :
114- df [ts_field ] = pd .to_datetime (df [ts_field ], errors = "coerce" , utc = True )
115-
116- for column , dtype in SCHEMA_DTYPES_MAP [schema ]:
117- if dtype .startswith ("S" ) and column not in struct ._hidden_fields :
118- df [column ] = df [column ].str .decode ("utf-8" )
119-
120- index_column = "ts_event" if schema .value .startswith ("ohlcv" ) else "ts_recv"
121- df .set_index (index_column , inplace = True )
122-
123- return df
124-
125-
12697class DataSource (abc .ABC ):
12798 """
12899 Abstract base class for backing DBNStore instances with data.
@@ -791,7 +762,7 @@ def to_csv(
791762 path : Path | str ,
792763 pretty_px : bool = True ,
793764 pretty_ts : bool = True ,
794- map_symbols : bool | None = None ,
765+ map_symbols : bool = True ,
795766 schema : Schema | str | None = None ,
796767 ) -> None :
797768 """
@@ -826,8 +797,12 @@ def to_csv(
826797 Requires all the data to be brought up into memory to then be written.
827798
828799 """
800+ price_type : Literal ["fixed" , "float" ] = "fixed"
801+ if pretty_px is True :
802+ price_type = "float"
803+
829804 df_iter = self .to_df (
830- pretty_px = pretty_px ,
805+ price_type = price_type ,
831806 pretty_ts = pretty_ts ,
832807 map_symbols = map_symbols ,
833808 schema = schema ,
@@ -844,7 +819,8 @@ def to_csv(
844819 @overload
845820 def to_df (
846821 self ,
847- pretty_px : bool = ...,
822+ pretty_px : bool | None = ...,
823+ price_type : Literal ["fixed" , "float" , "decimal" ] = ...,
848824 pretty_ts : bool = ...,
849825 map_symbols : bool = ...,
850826 schema : Schema | str | None = ...,
@@ -855,7 +831,8 @@ def to_df(
855831 @overload
856832 def to_df (
857833 self ,
858- pretty_px : bool = ...,
834+ pretty_px : bool | None = ...,
835+ price_type : Literal ["fixed" , "float" , "decimal" ] = ...,
859836 pretty_ts : bool = ...,
860837 map_symbols : bool = ...,
861838 schema : Schema | str | None = ...,
@@ -865,7 +842,8 @@ def to_df(
865842
866843 def to_df (
867844 self ,
868- pretty_px : bool = True ,
845+ pretty_px : bool | None = None ,
846+ price_type : Literal ["fixed" , "float" , "decimal" ] = "float" ,
869847 pretty_ts : bool = True ,
870848 map_symbols : bool = True ,
871849 schema : Schema | str | None = None ,
@@ -877,9 +855,15 @@ def to_df(
877855 Parameters
878856 ----------
879857 pretty_px : bool, default True
858+ This parameter is deprecated and will be removed in a future release.
880859 If all price columns should be converted from `int` to `float` at
881860 the correct scale (using the fixed-precision scalar 1e-9). Null
882861 prices are replaced with NaN.
862+ price_type : str, default "float"
863+ The price type to use for price fields.
864+ If "fixed", prices will have a type of `int` in fixed decimal format; each unit representing 1e-9 or 0.000000001.
865+ If "float", prices will have a type of `float`.
866+ If "decimal", prices will be instances of `decimal.Decimal`.
883867 pretty_ts : bool, default True
884868 If all timestamp columns should be converted from UNIX nanosecond
885869 `int` to tz-aware UTC `pd.Timestamp`.
@@ -908,6 +892,20 @@ def to_df(
908892 If the schema for the array cannot be determined.
909893
910894 """
895+ if pretty_px is True :
896+ warnings .warn (
897+ 'The argument `pretty_px` is deprecated and will be removed in a future release; `price_type="float"` can be used instead.' ,
898+ DeprecationWarning ,
899+ stacklevel = 2 ,
900+ )
901+ elif pretty_px is False :
902+ price_type = "fixed"
903+ warnings .warn (
904+ 'The argument `pretty_px` is deprecated and will be removed in a future release; `price_type="fixed"` can be used instead.' ,
905+ DeprecationWarning ,
906+ stacklevel = 2 ,
907+ )
908+
911909 schema = validate_maybe_enum (schema , Schema , "schema" )
912910 if schema is None :
913911 if self .schema is None :
@@ -927,7 +925,7 @@ def to_df(
927925 schema = schema ,
928926 count = count ,
929927 instrument_map = self ._instrument_map ,
930- pretty_px = pretty_px ,
928+ price_type = price_type ,
931929 pretty_ts = pretty_ts ,
932930 map_symbols = map_symbols ,
933931 )
@@ -966,7 +964,7 @@ def to_json(
966964 path : Path | str ,
967965 pretty_px : bool = True ,
968966 pretty_ts : bool = True ,
969- map_symbols : bool | None = None ,
967+ map_symbols : bool = True ,
970968 schema : Schema | str | None = None ,
971969 ) -> None :
972970 """
@@ -1000,8 +998,12 @@ def to_json(
1000998 Requires all the data to be brought up into memory to then be written.
1001999
10021000 """
1001+ price_type : Literal ["fixed" , "float" ] = "fixed"
1002+ if pretty_px is True :
1003+ price_type = "float"
1004+
10031005 df_iter = self .to_df (
1004- pretty_px = pretty_px ,
1006+ price_type = price_type ,
10051007 pretty_ts = pretty_ts ,
10061008 map_symbols = map_symbols ,
10071009 schema = schema ,
@@ -1126,37 +1128,91 @@ def __init__(
11261128 count : int | None ,
11271129 schema : Schema ,
11281130 instrument_map : InstrumentMap ,
1129- pretty_px : bool = True ,
1131+ price_type : Literal [ "fixed" , "float" , "decimal" ] = "float" ,
11301132 pretty_ts : bool = True ,
11311133 map_symbols : bool = True ,
11321134 ):
11331135 self ._records = records
11341136 self ._schema = schema
11351137 self ._count = count
1136- self ._pretty_px = pretty_px
1138+ self ._price_type = price_type
11371139 self ._pretty_ts = pretty_ts
11381140 self ._map_symbols = map_symbols
11391141 self ._instrument_map = instrument_map
1142+ self ._struct = SCHEMA_STRUCT_MAP [schema ]
11401143
11411144 def __iter__ (self ) -> DataFrameIterator :
11421145 return self
11431146
11441147 def __next__ (self ) -> pd .DataFrame :
1145- df = format_dataframe (
1146- pd .DataFrame (
1147- next (self ._records ),
1148- columns = SCHEMA_COLUMNS [self ._schema ],
1149- ),
1150- schema = self ._schema ,
1151- pretty_px = self ._pretty_px ,
1152- pretty_ts = self ._pretty_ts ,
1148+ df = pd .DataFrame (
1149+ next (self ._records ),
1150+ columns = SCHEMA_COLUMNS [self ._schema ],
11531151 )
11541152
1153+ if self ._schema == Schema .DEFINITION :
1154+ self ._format_definition_fields (df )
1155+
1156+ self ._format_hidden_fields (df )
1157+
1158+ self ._format_px (df , self ._price_type )
1159+
1160+ if self ._pretty_ts :
1161+ self ._format_pretty_ts (df )
1162+
1163+ self ._format_set_index (df )
1164+
11551165 if self ._map_symbols :
1156- df_index = df .index if self ._pretty_ts else pd .to_datetime (df .index , utc = True )
1157- dates = [ts .date () for ts in df_index ]
1158- df ["symbol" ] = [
1159- self ._instrument_map .resolve (inst , dates [i ]) for i , inst in enumerate (df ["instrument_id" ])
1160- ]
1166+ self ._format_map_symbols (df )
11611167
11621168 return df
1169+
1170+ def _format_definition_fields (self , df : pd .DataFrame ) -> None :
1171+ for column , type_max in DEFINITION_TYPE_MAX_MAP .items ():
1172+ if column in df .columns :
1173+ df [column ] = df [column ].where (df [column ] != type_max , np .nan )
1174+
1175+ def _format_hidden_fields (self , df : pd .DataFrame ) -> None :
1176+ for column , dtype in SCHEMA_DTYPES_MAP [self ._schema ]:
1177+ hidden_fields = self ._struct ._hidden_fields
1178+ if dtype .startswith ("S" ) and column not in hidden_fields :
1179+ df [column ] = df [column ].str .decode ("utf-8" )
1180+
1181+ def _format_map_symbols (self , df : pd .DataFrame ) -> None :
1182+ df_index = df .index if self ._pretty_ts else pd .to_datetime (df .index , utc = True )
1183+ dates = [ts .date () for ts in df_index ]
1184+ df ["symbol" ] = [
1185+ self ._instrument_map .resolve (inst , dates [i ])
1186+ for i , inst in enumerate (df ["instrument_id" ])
1187+ ]
1188+
1189+ def _format_px (
1190+ self ,
1191+ df : pd .DataFrame ,
1192+ price_type : Literal ["fixed" , "float" , "decimal" ],
1193+ ) -> None :
1194+ px_fields = self ._struct ._price_fields
1195+
1196+ if price_type == "decimal" :
1197+ for field in px_fields :
1198+ df [field ] = (
1199+ df [field ].replace (INT64_NULL , np .nan ).apply (decimal .Decimal )
1200+ / FIXED_PRICE_SCALE
1201+ )
1202+ elif price_type == "float" :
1203+ for field in px_fields :
1204+ df [field ] = df [field ].replace (INT64_NULL , np .nan ) / FIXED_PRICE_SCALE
1205+ else :
1206+ return # do nothing
1207+
1208+ def _format_pretty_ts (self , df : pd .DataFrame ) -> None :
1209+ for field in self ._struct ._timestamp_fields :
1210+ df [field ] = df [field ].apply (
1211+ partial (pd .to_datetime , utc = True , errors = "coerce" ),
1212+ )
1213+
1214+ def _format_set_index (self , df : pd .DataFrame ) -> None :
1215+ index_column = (
1216+ "ts_event" if self ._schema .value .startswith ("ohlcv" ) else "ts_recv"
1217+ )
1218+ df .set_index (index_column , inplace = True )
0 commit comments