1010from io import BytesIO
1111from os import PathLike
1212from pathlib import Path
13- from typing import IO , TYPE_CHECKING , Any , BinaryIO , Callable , Literal , overload
13+ from typing import (
14+ IO ,
15+ TYPE_CHECKING ,
16+ Any ,
17+ BinaryIO ,
18+ Callable ,
19+ Literal ,
20+ Protocol ,
21+ overload ,
22+ )
1423
1524import databento_dbn
1625import numpy as np
@@ -1072,20 +1081,43 @@ def to_ndarray(
10721081
10731082 """
10741083 schema = validate_maybe_enum (schema , Schema , "schema" )
1075- if schema is None :
1076- if self .schema is None :
1084+ ndarray_iter : NDArrayIterator
1085+
1086+ if self .schema is None :
1087+ # If schema is None, we're handling heterogeneous data from the live client.
1088+ # This is less performant because the records of a given schema are not contiguous in memory.
1089+ if schema is None :
10771090 raise ValueError ("a schema must be specified for mixed DBN data" )
1078- schema = self .schema
10791091
1080- dtype = SCHEMA_DTYPES_MAP [schema ]
1081- ndarray_iter = NDArrayIterator (
1082- filter (lambda r : isinstance (r , SCHEMA_STRUCT_MAP [schema ]), self ),
1083- dtype ,
1084- count ,
1085- )
1092+ schema_struct = SCHEMA_STRUCT_MAP [schema ]
1093+ schema_dtype = SCHEMA_DTYPES_MAP [schema ]
1094+ schema_filter = filter (lambda r : isinstance (r , schema_struct ), self )
1095+
1096+ ndarray_iter = NDArrayBytesIterator (
1097+ records = map (bytes , schema_filter ),
1098+ dtype = schema_dtype ,
1099+ count = count ,
1100+ )
1101+ else :
1102+ # If schema is set, we're handling homogeneous historical data.
1103+ schema_dtype = SCHEMA_DTYPES_MAP [self .schema ]
1104+
1105+ if self ._metadata .ts_out :
1106+ schema_dtype .append (("ts_out" , "u8" ))
1107+
1108+ if schema is not None and schema != self .schema :
1109+ # This is to maintain identical behavior with NDArrayBytesIterator
1110+ ndarray_iter = iter ([np .empty ([0 , 1 ], dtype = schema_dtype )])
1111+ else :
1112+ ndarray_iter = NDArrayStreamIterator (
1113+ reader = self .reader ,
1114+ dtype = schema_dtype ,
1115+ offset = self ._metadata_length ,
1116+ count = count ,
1117+ )
10861118
10871119 if count is None :
1088- return next (ndarray_iter , np .empty ([0 , 1 ], dtype = dtype ))
1120+ return next (ndarray_iter , np .empty ([0 , 1 ], dtype = schema_dtype ))
10891121
10901122 return ndarray_iter
10911123
@@ -1124,10 +1156,66 @@ def _transcode(
11241156 transcoder .flush ()
11251157
11261158
class NDArrayIterator(Protocol):
    """
    Structural interface for iterators that yield DBN records as structured
    NumPy arrays.

    Concrete implementations (stream- and bytes-backed) must be iterable and
    produce ``np.ndarray`` batches from ``__next__``.
    """

    @abc.abstractmethod
    def __iter__(self) -> NDArrayIterator:
        ...

    @abc.abstractmethod
    def __next__(self) -> np.ndarray[Any, Any]:
        ...
1167+
1168+
class NDArrayStreamIterator(NDArrayIterator):
    """
    Iterator for homogeneous byte streams of DBN records.

    Reads fixed-size records of ``dtype`` from ``reader`` beginning at
    ``offset``. When ``count`` is given, each iteration reads at most that
    many records; otherwise the remainder of the stream is consumed in a
    single read.
    """

    def __init__(
        self,
        reader: IO[bytes],
        dtype: list[tuple[str, str]],
        offset: int = 0,
        count: int | None = None,
    ) -> None:
        self._reader = reader
        self._dtype = np.dtype(dtype)
        self._offset = offset
        self._count = count

        # Skip past the leading metadata so reads begin at the first record.
        self._reader.seek(offset)

    def __iter__(self) -> NDArrayStreamIterator:
        return self

    def __next__(self) -> np.ndarray[Any, Any]:
        # -1 requests the whole remaining stream; otherwise read a whole
        # number of records (at least one, even when count == 0).
        read_size = (
            -1
            if self._count is None
            else self._dtype.itemsize * max(self._count, 1)
        )

        chunk = self._reader.read(read_size)
        if not chunk:
            raise StopIteration

        try:
            return np.frombuffer(
                buffer=chunk,
                dtype=self._dtype,
            )
        except ValueError:
            # A partial trailing record makes the buffer size a non-multiple
            # of the record size.
            raise BentoError(
                "DBN file is truncated or contains an incomplete record",
            )
1209+
1210+
1211+ class NDArrayBytesIterator (NDArrayIterator ):
1212+ """
1213+ Iterator for heterogeneous streams of DBN records.
1214+ """
1215+
11281216 def __init__ (
11291217 self ,
1130- records : Iterator [DBNRecord ],
1218+ records : Iterator [bytes ],
11311219 dtype : list [tuple [str , str ]],
11321220 count : int | None ,
11331221 ):
@@ -1144,22 +1232,33 @@ def __next__(self) -> np.ndarray[Any, Any]:
11441232 num_records = 0
11451233 for record in itertools .islice (self ._records , self ._count ):
11461234 num_records += 1
1147- record_bytes .write (bytes ( record ) )
1235+ record_bytes .write (record )
11481236
11491237 if num_records == 0 :
11501238 if self ._first_next :
11511239 return np .empty ([0 , 1 ], dtype = self ._dtype )
11521240 raise StopIteration
11531241
11541242 self ._first_next = False
1155- return np .frombuffer (
1156- record_bytes .getvalue (),
1157- dtype = self ._dtype ,
1158- count = num_records ,
1159- )
1243+
1244+ try :
1245+ return np .frombuffer (
1246+ record_bytes .getbuffer (),
1247+ dtype = self ._dtype ,
1248+ count = num_records ,
1249+ )
1250+ except ValueError :
1251+ raise BentoError (
1252+ "DBN file is truncated or contains an incomplete record" ,
1253+ )
11601254
11611255
11621256class DataFrameIterator :
1257+ """
1258+ Iterator for DataFrames that supports batching and column formatting for
1259+ DBN records.
1260+ """
1261+
11631262 def __init__ (
11641263 self ,
11651264 records : Iterator [np .ndarray [Any , Any ]],
0 commit comments