Skip to content

Commit 8e7c94e

Browse files
committed
ADD: DBNStore.to_df tz parameter
1 parent ee406b4 commit 8e7c94e

File tree

5 files changed

+124
-4
lines changed

5 files changed

+124
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Changelog
22

3-
## 0.28.1 - TBD
3+
## 0.29.0 - TBD
4+
5+
#### Enhancements
6+
- Added `tz` parameter to `DBNStore.to_df` which will convert all timestamp fields from UTC to a specified timezone when used with `pretty_ts`
47

58
#### Bug fixes
69
- `Live.block_for_close` and `Live.wait_for_close` will now call `Live.stop` when a timeout is reached instead of `Live.terminate` to close the stream more gracefully

databento/common/dbnstore.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pandas as pd
2727
import pyarrow as pa
2828
import pyarrow.parquet as pq
29+
import pytz
2930
import zstandard
3031
from databento_dbn import FIXED_PRICE_SCALE
3132
from databento_dbn import Compression
@@ -47,6 +48,7 @@
4748
from databento.common.error import BentoError
4849
from databento.common.symbology import InstrumentMap
4950
from databento.common.types import DBNRecord
51+
from databento.common.types import Default
5052
from databento.common.validation import validate_enum
5153
from databento.common.validation import validate_file_write_path
5254
from databento.common.validation import validate_maybe_enum
@@ -830,6 +832,7 @@ def to_df(
830832
pretty_ts: bool = ...,
831833
map_symbols: bool = ...,
832834
schema: Schema | str | None = ...,
835+
tz: pytz.BaseTzInfo | str = ...,
833836
count: None = ...,
834837
) -> pd.DataFrame:
835838
...
@@ -841,6 +844,7 @@ def to_df(
841844
pretty_ts: bool = ...,
842845
map_symbols: bool = ...,
843846
schema: Schema | str | None = ...,
847+
tz: pytz.BaseTzInfo | str = ...,
844848
count: int = ...,
845849
) -> DataFrameIterator:
846850
...
@@ -851,6 +855,7 @@ def to_df(
851855
pretty_ts: bool = True,
852856
map_symbols: bool = True,
853857
schema: Schema | str | None = None,
858+
tz: pytz.BaseTzInfo | str | Default[pytz.BaseTzInfo] = Default[pytz.BaseTzInfo](pytz.UTC),
854859
count: int | None = None,
855860
) -> pd.DataFrame | DataFrameIterator:
856861
"""
@@ -865,14 +870,16 @@ def to_df(
865870
If "decimal", prices will be instances of `decimal.Decimal`.
866871
pretty_ts : bool, default True
867872
If all timestamp columns should be converted from UNIX nanosecond
868-
`int` to tz-aware UTC `pd.Timestamp`.
873+
`int` to tz-aware `pd.Timestamp`. The timezone can be specified using the `tz` parameter.
869874
map_symbols : bool, default True
870875
If symbology mappings from the metadata should be used to create
871876
a 'symbol' column, mapping the instrument ID to its requested symbol for
872877
every record.
873878
schema : Schema or str, optional
874879
The DBN schema for the dataframe.
875880
This is only required when reading a DBN stream with mixed record types.
881+
tz : pytz.BaseTzInfo or str, default UTC
882+
If `pretty_ts` is `True`, all timestamps will be converted to the specified timezone.
876883
count : int, optional
877884
If set, instead of returning a single `DataFrame` a `DataFrameIterator`
878885
instance will be returned. When iterated, this object will yield
@@ -892,6 +899,14 @@ def to_df(
892899
893900
"""
894901
schema = validate_maybe_enum(schema, Schema, "schema")
902+
903+
if isinstance(tz, Default):
904+
tz = tz.value # consume default
905+
elif not pretty_ts:
906+
raise ValueError("A timezone was specified when `pretty_ts` is `False`. Did you mean to set `pretty_ts=True`?")
907+
908+
if not isinstance(tz, pytz.BaseTzInfo):
909+
tz = pytz.timezone(tz)
895910
if schema is None:
896911
if self.schema is None:
897912
raise ValueError("a schema must be specified for mixed DBN data")
@@ -910,6 +925,7 @@ def to_df(
910925
count=count,
911926
struct_type=self._schema_struct_map[schema],
912927
instrument_map=self._instrument_map,
928+
tz=tz,
913929
price_type=price_type,
914930
pretty_ts=pretty_ts,
915931
map_symbols=map_symbols,
@@ -1334,6 +1350,7 @@ def __init__(
13341350
count: int | None,
13351351
struct_type: type[DBNRecord],
13361352
instrument_map: InstrumentMap,
1353+
tz: pytz.BaseTzInfo,
13371354
price_type: Literal["fixed", "float", "decimal"] = "float",
13381355
pretty_ts: bool = True,
13391356
map_symbols: bool = True,
@@ -1345,6 +1362,7 @@ def __init__(
13451362
self._pretty_ts = pretty_ts
13461363
self._map_symbols = map_symbols
13471364
self._instrument_map = instrument_map
1365+
self._tz = tz
13481366

13491367
def __iter__(self) -> DataFrameIterator:
13501368
return self
@@ -1411,7 +1429,7 @@ def _format_px(
14111429

14121430
def _format_pretty_ts(self, df: pd.DataFrame) -> None:
14131431
for field in self._struct_type._timestamp_fields:
1414-
df[field] = pd.to_datetime(df[field], utc=True, errors="coerce")
1432+
df[field] = pd.to_datetime(df[field], utc=True, errors="coerce").dt.tz_convert(self._tz)
14151433

14161434
def _format_set_index(self, df: pd.DataFrame) -> None:
14171435
index_column = self._struct_type._ordered_fields[0]

databento/common/types.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Union
1+
from typing import Callable, Generic, TypeVar, Union
22

33
import databento_dbn
44

@@ -21,3 +21,34 @@
2121

2222
RecordCallback = Callable[[DBNRecord], None]
2323
ExceptionCallback = Callable[[Exception], None]
24+
25+
_T = TypeVar("_T")
26+
class Default(Generic[_T]):
27+
"""
28+
A container for a default value. This is to be used when a callable wants
29+
to detect if a default parameter value is being used.
30+
31+
Example
32+
-------
33+
def foo(param=Default[int](10)):
34+
if isinstance(param, Default):
35+
print(f"param={param.value} (default)")
36+
else:
37+
print(f"param={param.value}")
38+
39+
"""
40+
41+
def __init__(self, value: _T):
42+
self._value = value
43+
44+
@property
45+
def value(self) -> _T:
46+
"""
47+
The default value.
48+
49+
Returns
50+
-------
51+
_T
52+
53+
"""
54+
return self._value

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ ruff = "^0.0.291"
5151
types-requests = "^2.30.0.0"
5252
tomli = "^2.0.1"
5353
teamcity-messages = "^1.32"
54+
types-pytz = "^2024.1.0.20240203"
5455

5556
[build-system]
5657
requires = ["poetry-core"]

tests/test_historical_bento.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
import numpy as np
1414
import pandas as pd
1515
import pytest
16+
import pytz
1617
import zstandard
18+
from databento.common.constants import SCHEMA_STRUCT_MAP
1719
from databento.common.dbnstore import DBNStore
1820
from databento.common.error import BentoError
1921
from databento.common.publishers import Dataset
@@ -1330,3 +1332,68 @@ def test_dbnstore_to_df_cannot_map_symbols_default_to_false(
13301332

13311333
# Assert
13321334
assert len(df_iter) == 4
1335+
1336+
1337+
@pytest.mark.parametrize(
1338+
"timezone",
1339+
[
1340+
"US/Central",
1341+
"US/Eastern",
1342+
"Europe/Vienna",
1343+
"Asia/Dubai",
1344+
"UTC",
1345+
],
1346+
)
1347+
@pytest.mark.parametrize(
1348+
"schema",
1349+
[pytest.param(schema, id=str(schema)) for schema in Schema.variants()],
1350+
)
1351+
def test_dbnstore_to_df_with_timezone(
1352+
test_data: Callable[[Dataset, Schema], bytes],
1353+
schema: Schema,
1354+
timezone: str,
1355+
) -> None:
1356+
"""
1357+
Test that setting the `tz` parameter in `DBNStore.to_df` converts all
1358+
timestamp fields into the specified timezone.
1359+
"""
1360+
# Arrange
1361+
dbn_stub_data = (
1362+
zstandard.ZstdDecompressor().stream_reader(test_data(Dataset.GLBX_MDP3, schema)).read()
1363+
)
1364+
dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
1365+
1366+
# Act
1367+
df = dbnstore.to_df(tz=timezone)
1368+
df.reset_index(inplace=True)
1369+
1370+
# Assert
1371+
expected_timezone = pytz.timezone(timezone)._utcoffset
1372+
failures = []
1373+
struct = SCHEMA_STRUCT_MAP[schema]
1374+
for field in struct._timestamp_fields:
1375+
if df[field].dt.tz._utcoffset != expected_timezone:
1376+
failures.append(field)
1377+
1378+
assert not failures
1379+
1380+
1381+
def test_dbnstore_to_df_with_timezone_pretty_ts_error(
1382+
test_data: Callable[[Dataset, Schema], bytes],
1383+
) -> None:
1384+
"""
1385+
Test that setting the `tz` parameter in `DBNStore.to_df` when `pretty_ts`
1386+
is `False` causes an error.
1387+
"""
1388+
# Arrange
1389+
dbn_stub_data = (
1390+
zstandard.ZstdDecompressor().stream_reader(test_data(Dataset.GLBX_MDP3, Schema.MBO)).read()
1391+
)
1392+
dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
1393+
1394+
# Act, Assert
1395+
with pytest.raises(ValueError):
1396+
dbnstore.to_df(
1397+
pretty_ts=False,
1398+
tz=pytz.UTC,
1399+
)

0 commit comments

Comments
 (0)