Commit c1749aa

MOD: Use Transcoder for DBNStore encoders

1 parent 686ce41

File tree: 8 files changed, +202 −172 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,10 @@
 
 #### Enhancements
 - Added `price_type` argument for `DBNStore.to_df` to specify if price fields should be `fixed`, `float` or `decimal.Decimal`
+- Upgraded `databento-dbn` to 0.12.0
+
+#### Breaking Changes
+- Changed outputs of `DBNStore.to_csv` and `DBNStore.to_json` to match the encoding formats from the Databento API
 
 #### Deprecations
 - Deprecated `pretty_px` argument for `DBNStore.to_df` to be removed in a future release; the default `pretty_px=True` is now equivalent to `price_type="float"` and `pretty_px=False` is now equivalent to `price_type="fixed"`
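The deprecation maps one-to-one onto the new argument. A sketch of the equivalence, assuming a `DBNStore` loaded from a local file (the path is illustrative):

from databento import DBNStore

store = DBNStore.from_file("data.dbn")  # illustrative path

df_float = store.to_df(price_type="float")    # was: pretty_px=True (the default)
df_fixed = store.to_df(price_type="fixed")    # was: pretty_px=False
df_dec = store.to_df(price_type="decimal")    # new: prices as decimal.Decimal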

README.md

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ The library is fully compatible with the latest distribution of Anaconda 3.8 and
 The minimum dependencies as found in the `pyproject.toml` are also listed below:
 - python = "^3.8"
 - aiohttp = "^3.8.3"
-- databento-dbn = "0.11.1"
+- databento-dbn = "0.12.0"
 - numpy = ">=1.23.5"
 - pandas = ">=1.5.3"
 - requests = ">=2.24.0"

databento/common/dbnstore.py

Lines changed: 91 additions & 36 deletions

@@ -11,7 +11,7 @@
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Literal, overload
+from typing import IO, TYPE_CHECKING, Any, BinaryIO, Callable, Literal, overload
 
 import databento_dbn
 import numpy as np
@@ -20,19 +20,24 @@
 from databento_dbn import FIXED_PRICE_SCALE
 from databento_dbn import Compression
 from databento_dbn import DBNDecoder
+from databento_dbn import Encoding
 from databento_dbn import ErrorMsg
 from databento_dbn import Metadata
 from databento_dbn import Schema
 from databento_dbn import SType
 from databento_dbn import SymbolMappingMsg
 from databento_dbn import SystemMsg
+from databento_dbn import Transcoder
 
 from databento.common.data import DEFINITION_TYPE_MAX_MAP
 from databento.common.data import SCHEMA_COLUMNS
 from databento.common.data import SCHEMA_DTYPES_MAP
 from databento.common.data import SCHEMA_STRUCT_MAP
 from databento.common.error import BentoError
+from databento.common.iterator import chunk
 from databento.common.symbology import InstrumentMap
+from databento.common.symbology import SymbolInterval
+from databento.common.validation import validate_enum
 from databento.common.validation import validate_file_write_path
 from databento.common.validation import validate_maybe_enum
 from databento.live import DBNRecord
@@ -763,6 +768,7 @@ def to_csv(
         pretty_px: bool = True,
         pretty_ts: bool = True,
         map_symbols: bool = True,
+        compression: Compression | str = Compression.NONE,
         schema: Schema | str | None = None,
     ) -> None:
         """
@@ -783,6 +789,8 @@ def to_csv(
             If symbology mappings from the metadata should be used to create
             a 'symbol' column, mapping the instrument ID to its requested symbol for
             every record.
+        compression : Compression or str, default `Compression.NONE`
+            The output compression for writing.
         schema : Schema or str, optional
             The schema for the csv.
             This is only required when reading a DBN stream with mixed record types.
@@ -797,24 +805,33 @@
         Requires all the data to be brought up into memory to then be written.
 
         """
-        price_type: Literal["fixed", "float"] = "fixed"
-        if pretty_px is True:
-            price_type = "float"
+        compression = validate_enum(compression, Compression, "compression")
+        schema = validate_maybe_enum(schema, Schema, "schema")
+        if schema is None:
+            if self.schema is None:
+                raise ValueError("a schema must be specified for mixed DBN data")
+            schema = self.schema
 
-        df_iter = self.to_df(
-            price_type=price_type,
-            pretty_ts=pretty_ts,
-            map_symbols=map_symbols,
-            schema=schema,
-            count=2**16,
-        )
+        record_type = SCHEMA_STRUCT_MAP[schema]
+        record_iter = filter(lambda r: isinstance(r, record_type), self)
 
-        with open(path, "x", newline="") as csv_file:
-            for i, frame in enumerate(df_iter):
-                frame.to_csv(
-                    csv_file,
-                    header=(i == 0),
-                )
+        if map_symbols:
+            self._instrument_map.insert_metadata(self.metadata)
+            symbol_map = self._instrument_map._data
+        else:
+            symbol_map = None
+
+        with open(path, "xb") as output:
+            self._transcode(
+                output=output,
+                records_iter=record_iter,
+                encoding=Encoding.CSV,
+                pretty_px=pretty_px,
+                pretty_ts=pretty_ts,
+                symbol_map=symbol_map,
+                compression=compression,
+                schema=schema,
+            )
 
     @overload
     def to_df(
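For context outside the diff, a minimal sketch of the new `to_csv` call path. The file paths are illustrative, and `Compression.ZSTD` assumes the zstd member exposed by the `databento_dbn` `Compression` enum.

from databento import DBNStore
from databento_dbn import Compression

store = DBNStore.from_file("data.dbn")  # illustrative input path

# CSV is now produced by the DBN Transcoder rather than pandas, so the
# output matches the Databento API's CSV encoding.
store.to_csv(
    "data.csv.zst",
    pretty_px=True,    # human-readable prices
    pretty_ts=True,    # human-readable timestamps
    map_symbols=True,  # resolve instrument IDs to symbols
    compression=Compression.ZSTD,  # new argument: compress the output
)

Note the `open(path, "xb")` in the diff: the output file is now opened in binary mode and must not already exist.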
@@ -965,6 +982,7 @@ def to_json(
         pretty_px: bool = True,
         pretty_ts: bool = True,
         map_symbols: bool = True,
+        compression: Compression | str = Compression.NONE,
         schema: Schema | str | None = None,
     ) -> None:
         """
@@ -984,6 +1002,8 @@ def to_json(
             If symbology mappings from the metadata should be used to create
             a 'symbol' column, mapping the instrument ID to its requested symbol for
             every record.
+        compression : Compression or str, default `Compression.NONE`
+            The output compression for writing.
         schema : Schema or str, optional
             The schema for the json.
             This is only required when reading a DBN stream with mixed record types.
@@ -998,27 +1018,33 @@
         Requires all the data to be brought up into memory to then be written.
 
         """
-        price_type: Literal["fixed", "float"] = "fixed"
-        if pretty_px is True:
-            price_type = "float"
+        compression = validate_enum(compression, Compression, "compression")
+        schema = validate_maybe_enum(schema, Schema, "schema")
+        if schema is None:
+            if self.schema is None:
+                raise ValueError("a schema must be specified for mixed DBN data")
+            schema = self.schema
 
-        df_iter = self.to_df(
-            price_type=price_type,
-            pretty_ts=pretty_ts,
-            map_symbols=map_symbols,
-            schema=schema,
-            count=2**16,
-        )
+        record_type = SCHEMA_STRUCT_MAP[schema]
+        record_iter = filter(lambda r: isinstance(r, record_type), self)
 
-        with open(path, "x") as json_path:
-            for frame in df_iter:
-                frame.reset_index(inplace=True)
-                frame.to_json(
-                    json_path,
-                    orient="records",
-                    date_unit="ns",
-                    lines=True,
-                )
+        if map_symbols:
+            self._instrument_map.insert_metadata(self.metadata)
+            symbol_map = self._instrument_map._data
+        else:
+            symbol_map = None
+
+        with open(path, "xb") as output:
+            self._transcode(
+                output=output,
+                records_iter=record_iter,
+                encoding=Encoding.JSON,
+                pretty_px=pretty_px,
+                pretty_ts=pretty_ts,
+                symbol_map=symbol_map,
+                compression=compression,
+                schema=schema,
+            )
 
     @overload
     def to_ndarray(  # type: ignore [misc]
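`to_json` takes the identical path with `Encoding.JSON`. Continuing the sketch above: for a store holding mixed record types, `schema` selects which records are encoded, and omitting it raises `ValueError` per the validation added above.

# Continuing the earlier sketch; "trades" assumes the store actually
# contains trades records.
store.to_json(
    "trades.json",
    schema="trades",
    map_symbols=True,
)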
@@ -1085,6 +1111,35 @@ def to_ndarray(
 
         return ndarray_iter
 
+    def _transcode(
+        self,
+        output: BinaryIO,
+        records_iter: Iterator[DBNRecord],
+        encoding: Encoding,
+        pretty_px: bool,
+        pretty_ts: bool,
+        symbol_map: dict[int, list[SymbolInterval]] | None,
+        compression: Compression,
+        schema: Schema,
+    ) -> None:
+        transcoder = Transcoder(
+            file=output,
+            encoding=encoding,
+            compression=compression,
+            pretty_px=pretty_px,
+            pretty_ts=pretty_ts,
+            has_metadata=True,
+            input_compression=Compression.NONE,
+            symbol_map=symbol_map,  # type: ignore [arg-type]
+            schema=schema,
+        )
+
+        transcoder.write(bytes(self.metadata))
+        for records in chunk(records_iter, 2**16):
+            for record in records:
+                transcoder.write(bytes(record))
+        transcoder.flush()
+
 
 class NDArrayIterator:
     def __init__(
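The `_transcode` helper above is the entire integration: write the stream's `Metadata` first, then feed records through in chunks of 2**16, then flush. A stand-alone sketch of the same flow against an in-memory buffer; it assumes the `Transcoder` keyword arguments not shown here (`symbol_map`, `schema`) are optional, and `raw_dbn` is a placeholder for a complete, uncompressed DBN byte stream.

from io import BytesIO

from databento_dbn import Compression
from databento_dbn import Encoding
from databento_dbn import Transcoder

output = BytesIO()
transcoder = Transcoder(
    file=output,
    encoding=Encoding.CSV,
    compression=Compression.NONE,
    pretty_px=True,
    pretty_ts=True,
    has_metadata=True,  # the stream begins with DBN metadata
    input_compression=Compression.NONE,
)
transcoder.write(raw_dbn)  # placeholder: metadata bytes, then record bytes
transcoder.flush()
print(output.getvalue().decode())  # the transcoded CSV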

databento/common/iterator.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import itertools
+from collections.abc import Iterable
+from typing import TypeVar
+
+
+_C = TypeVar("_C")
+
+
+def chunk(iterable: Iterable[_C], size: int) -> Iterable[tuple[_C, ...]]:
+    """
+    Break an iterable into chunks with a length of at most `size`.
+
+    Parameters
+    ----------
+    iterable: Iterable[_C]
+        The iterable to break up.
+    size : int
+        The maximum size of each chunk.
+
+    Returns
+    -------
+    Iterable[_C]
+
+    Raises
+    ------
+    ValueError
+        If `size` is less than 1.
+
+    """
+    if size < 1:
+        raise ValueError("size must be at least 1")
+
+    it = iter(iterable)
+    return iter(lambda: tuple(itertools.islice(it, size)), ())
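The two-argument form `iter(callable, sentinel)` calls the lambda repeatedly and stops once it returns the sentinel `()`, i.e. once `itertools.islice` finds the underlying iterator exhausted. A quick sketch of the resulting behavior:

from databento.common.iterator import chunk

print(list(chunk(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]
print(list(chunk([], 3)))        # [] -- an empty iterable yields no chunks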

databento/live/protocol.py

Lines changed: 1 addition & 33 deletions

@@ -1,19 +1,18 @@
 from __future__ import annotations
 
 import asyncio
-import itertools
 import logging
 from collections.abc import Iterable
 from functools import singledispatchmethod
 from numbers import Number
-from typing import TypeVar
 
 import databento_dbn
 from databento_dbn import Schema
 from databento_dbn import SType
 
 from databento.common import cram
 from databento.common.error import BentoError
+from databento.common.iterator import chunk
 from databento.common.parsing import optional_datetime_to_unix_nanoseconds
 from databento.common.parsing import optional_symbols_list_to_list
 from databento.common.publishers import Dataset
@@ -36,37 +35,6 @@
 logger = logging.getLogger(__name__)
 
 
-_C = TypeVar("_C")
-
-
-def chunk(iterable: Iterable[_C], size: int) -> Iterable[tuple[_C, ...]]:
-    """
-    Break an iterable into chunks with a length of at most `size`.
-
-    Parameters
-    ----------
-    iterable: Iterable[_C]
-        The iterable to break up.
-    size : int
-        The maximum size of each chunk.
-
-    Returns
-    -------
-    Iterable[_C]
-
-    Raises
-    ------
-    ValueError
-        If `size` is less than 1.
-
-    """
-    if size < 1:
-        raise ValueError("size must be at least 1")
-
-    it = iter(iterable)
-    return iter(lambda: tuple(itertools.islice(it, size)), ())
-
-
 class DatabentoLiveProtocol(asyncio.BufferedProtocol):
     """
     A BufferedProtocol implementation for the Databento live subscription

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ repository = "https://github.com/databento/databento-python"
 [tool.poetry.dependencies]
 python = "^3.8"
 aiohttp = "^3.8.3"
-databento-dbn = "0.11.1"
+databento-dbn = "0.12.0"
 numpy = ">=1.23.5"
 pandas = ">=1.5.3"
 requests = ">=2.24.0"

tests/test_common_iterator.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+import pytest
+from databento.common import iterator
+
+
+@pytest.mark.parametrize(
+    "things, size, expected",
+    [
+        (
+            "abcdefg",
+            2,
+            [
+                ("a", "b"),
+                ("c", "d"),
+                ("e", "f"),
+                ("g",),
+            ],
+        ),
+    ],
+)
+def test_chunk(
+    things: Iterable[object],
+    size: int,
+    expected: Iterable[tuple[object]],
+) -> None:
+    """
+    Test that an iterable is chunked properly.
+    """
+    chunks = [chunk for chunk in iterator.chunk(things, size)]
+    assert chunks == expected
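A companion case this commit does not include, sketched here to cover the `ValueError` branch documented in `chunk`:

import pytest

from databento.common import iterator


@pytest.mark.parametrize("size", [0, -1])
def test_chunk_invalid_size(size: int) -> None:
    """
    Test that a chunk size below 1 raises ValueError.
    """
    with pytest.raises(ValueError):
        iterator.chunk("abc", size)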
