
Commit e53d3ca

MOD: Improve performance of DBNStore.to_ndarray
1 parent 4048e4c commit e53d3ca

2 files changed: 122 additions & 23 deletions

databento/common/dbnstore.py

Lines changed: 118 additions & 19 deletions
@@ -10,7 +10,16 @@
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, BinaryIO, Callable, Literal, overload
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Literal,
+    Protocol,
+    overload,
+)
 
 import databento_dbn
 import numpy as np
@@ -1072,20 +1081,43 @@ def to_ndarray(
         """
         schema = validate_maybe_enum(schema, Schema, "schema")
-        if schema is None:
-            if self.schema is None:
+        ndarray_iter: NDArrayIterator
+
+        if self.schema is None:
+            # If schema is None, we're handling heterogeneous data from the live client.
+            # This is less performant because the records of a given schema are not contiguous in memory.
+            if schema is None:
                 raise ValueError("a schema must be specified for mixed DBN data")
-            schema = self.schema
 
-        dtype = SCHEMA_DTYPES_MAP[schema]
-        ndarray_iter = NDArrayIterator(
-            filter(lambda r: isinstance(r, SCHEMA_STRUCT_MAP[schema]), self),
-            dtype,
-            count,
-        )
+            schema_struct = SCHEMA_STRUCT_MAP[schema]
+            schema_dtype = SCHEMA_DTYPES_MAP[schema]
+            schema_filter = filter(lambda r: isinstance(r, schema_struct), self)
+
+            ndarray_iter = NDArrayBytesIterator(
+                records=map(bytes, schema_filter),
+                dtype=schema_dtype,
+                count=count,
+            )
+        else:
+            # If schema is set, we're handling homogeneous historical data.
+            schema_dtype = SCHEMA_DTYPES_MAP[self.schema]
+
+            if self._metadata.ts_out:
+                schema_dtype.append(("ts_out", "u8"))
+
+            if schema is not None and schema != self.schema:
+                # This is to maintain identical behavior with NDArrayBytesIterator
+                ndarray_iter = iter([np.empty([0, 1], dtype=schema_dtype)])
+            else:
+                ndarray_iter = NDArrayStreamIterator(
+                    reader=self.reader,
+                    dtype=schema_dtype,
+                    offset=self._metadata_length,
+                    count=count,
+                )
 
         if count is None:
-            return next(ndarray_iter, np.empty([0, 1], dtype=dtype))
+            return next(ndarray_iter, np.empty([0, 1], dtype=schema_dtype))
 
         return ndarray_iter
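The core of the speedup is visible here: historical DBN data is homogeneous, so the raw byte stream past the metadata header can be handed to np.frombuffer in a single vectorized parse, rather than materializing every record as a Python object and re-serializing it. A minimal standalone sketch of that fast path, with a hypothetical two-field record layout and file name standing in for a real DBN schema:

    import numpy as np

    # Hypothetical fixed-width record layout standing in for a DBN schema dtype.
    dtype = np.dtype([("ts_event", "u8"), ("price", "i8")])

    with open("records.bin", "rb") as reader:  # placeholder file
        reader.seek(0)  # the real code seeks past the DBN metadata header
        buffer = reader.read()

    # One vectorized parse of the contiguous, homogeneous byte stream.
    array = np.frombuffer(buffer, dtype=dtype)

Mixed-schema data from the live client still takes the slower per-record path, since records of the requested schema are not contiguous in the stream.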

@@ -1124,10 +1156,66 @@ def _transcode(
         transcoder.flush()
 
 
-class NDArrayIterator:
+class NDArrayIterator(Protocol):
+    @abc.abstractmethod
+    def __iter__(self) -> NDArrayIterator:
+        ...
+
+    @abc.abstractmethod
+    def __next__(self) -> np.ndarray[Any, Any]:
+        ...
+
+
+class NDArrayStreamIterator(NDArrayIterator):
+    """
+    Iterator for homogeneous byte streams of DBN records.
+    """
+
+    def __init__(
+        self,
+        reader: IO[bytes],
+        dtype: list[tuple[str, str]],
+        offset: int = 0,
+        count: int | None = None,
+    ) -> None:
+        self._reader = reader
+        self._dtype = np.dtype(dtype)
+        self._offset = offset
+        self._count = count
+
+        self._reader.seek(offset)
+
+    def __iter__(self) -> NDArrayStreamIterator:
+        return self
+
+    def __next__(self) -> np.ndarray[Any, Any]:
+        if self._count is None:
+            read_size = -1
+        else:
+            read_size = self._dtype.itemsize * max(self._count, 1)
+
+        if buffer := self._reader.read(read_size):
+            try:
+                return np.frombuffer(
+                    buffer=buffer,
+                    dtype=self._dtype,
+                )
+            except ValueError:
+                raise BentoError(
+                    "DBN file is truncated or contains an incomplete record",
+                )
+
+        raise StopIteration
+
+
+class NDArrayBytesIterator(NDArrayIterator):
+    """
+    Iterator for heterogeneous streams of DBN records.
+    """
+
     def __init__(
         self,
-        records: Iterator[DBNRecord],
+        records: Iterator[bytes],
        dtype: list[tuple[str, str]],
         count: int | None,
     ):
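NDArrayIterator is now a typing.Protocol, so the two concrete iterators, and even the bare iter([np.empty([0, 1], ...)]) fallback in to_ndarray, satisfy the annotation structurally rather than by inheritance. A small sketch of that structural check (the consumer function is illustrative, not part of the commit):

    from typing import Any, Protocol

    import numpy as np

    class NDArrayIterator(Protocol):
        def __iter__(self) -> "NDArrayIterator": ...
        def __next__(self) -> "np.ndarray[Any, Any]": ...

    def total_rows(chunks: NDArrayIterator) -> int:
        # Any object with matching __iter__/__next__ conforms by shape alone;
        # no subclassing is needed, so a plain list iterator type-checks too.
        return sum(len(chunk) for chunk in chunks)

    print(total_rows(iter([np.zeros(3), np.zeros(2)])))  # 5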
@@ -1144,22 +1232,33 @@ def __next__(self) -> np.ndarray[Any, Any]:
         num_records = 0
         for record in itertools.islice(self._records, self._count):
             num_records += 1
-            record_bytes.write(bytes(record))
+            record_bytes.write(record)
 
         if num_records == 0:
             if self._first_next:
                 return np.empty([0, 1], dtype=self._dtype)
             raise StopIteration
 
         self._first_next = False
-        return np.frombuffer(
-            record_bytes.getvalue(),
-            dtype=self._dtype,
-            count=num_records,
-        )
+
+        try:
+            return np.frombuffer(
+                record_bytes.getbuffer(),
+                dtype=self._dtype,
+                count=num_records,
+            )
+        except ValueError:
+            raise BentoError(
+                "DBN file is truncated or contains an incomplete record",
+            )
 
 
 class DataFrameIterator:
+    """
+    Iterator for DataFrames that supports batching and column formatting for
+    DBN records.
+    """
+
     def __init__(
         self,
         records: Iterator[np.ndarray[Any, Any]],
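Two smaller changes land in this hunk: records now arrive pre-serialized as bytes (via the map(bytes, ...) in to_ndarray), so the hot loop writes them straight into the BytesIO, and getbuffer() gives np.frombuffer a zero-copy memoryview where getvalue() would have built a fresh bytes copy of each batch. A quick sketch of the getbuffer() difference:

    from io import BytesIO

    import numpy as np

    record_bytes = BytesIO()
    record_bytes.write(np.arange(4, dtype="u8").tobytes())

    # getbuffer() exposes the internal buffer as a memoryview without copying;
    # getvalue() would allocate a new bytes object of the same size first.
    array = np.frombuffer(record_bytes.getbuffer(), dtype="u8", count=4)
    print(array)  # [0 1 2 3]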

tests/test_historical_bento.py

Lines changed: 4 additions & 4 deletions
@@ -905,8 +905,8 @@ def test_dbnstore_to_ndarray_with_count(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    nd_iter = dbnstore.to_ndarray(count=count)
     expected = dbnstore.to_ndarray()
+    nd_iter = dbnstore.to_ndarray(count=count)
 
     # Assert
     aggregator: list[np.ndarray[Any, Any]] = []
@@ -935,8 +935,8 @@ def test_dbnstore_to_ndarray_with_schema(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    actual = dbnstore.to_ndarray(schema=schema)
     expected = dbnstore.to_ndarray()
+    actual = dbnstore.to_ndarray(schema=schema)
 
     # Assert
     for i, row in enumerate(actual):
@@ -1014,8 +1014,8 @@ def test_dbnstore_to_df_with_count(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    df_iter = dbnstore.to_df(count=count)
     expected = dbnstore.to_df()
+    df_iter = dbnstore.to_df(count=count)
 
     # Assert
     aggregator: list[pd.DataFrame] = []
@@ -1048,8 +1048,8 @@ def test_dbnstore_to_df_with_schema(
     # Act
     dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
 
-    actual = dbnstore.to_df(schema=schema)
     expected = dbnstore.to_df()
+    actual = dbnstore.to_df(schema=schema)
 
     # Assert
     assert actual.equals(expected)
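The call order in these tests is swapped because to_ndarray(count=...) and to_df(count=...) now return lazy iterators backed by the store's reader, so the full expected result is materialized before the batched iteration starts. A usage sketch of the batched API (the file path is a placeholder):

    import numpy as np
    from databento import DBNStore

    store = DBNStore.from_file("ohlcv-1s.dbn")  # placeholder path

    # Consume the store in chunks of at most 1,000 records each.
    batches = list(store.to_ndarray(count=1_000))
    full = np.concatenate(batches) if batches else np.empty([0, 1])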
