2929from databento_dbn import Compression
3030from databento_dbn import DBNDecoder
3131from databento_dbn import Encoding
32- from databento_dbn import ErrorMsg
32+ from databento_dbn import InstrumentDefMsg
33+ from databento_dbn import InstrumentDefMsgV1
3334from databento_dbn import Metadata
3435from databento_dbn import Schema
3536from databento_dbn import SType
36- from databento_dbn import SymbolMappingMsg
37- from databento_dbn import SystemMsg
3837from databento_dbn import Transcoder
38+ from databento_dbn import VersionUpgradePolicy
3939
40- from databento .common .data import DEFINITION_TYPE_MAX_MAP
41- from databento .common .data import SCHEMA_COLUMNS
42- from databento .common .data import SCHEMA_DTYPES_MAP
43- from databento .common .data import SCHEMA_STRUCT_MAP
40+ from databento .common .constants import DEFINITION_TYPE_MAX_MAP
41+ from databento .common .constants import INT64_NULL
42+ from databento .common .constants import SCHEMA_STRUCT_MAP
43+ from databento .common .constants import SCHEMA_STRUCT_MAP_V1
4444from databento .common .error import BentoError
45- from databento .common .iterator import chunk
4645from databento .common .symbology import InstrumentMap
46+ from databento .common .types import DBNRecord
4747from databento .common .validation import validate_enum
4848from databento .common .validation import validate_file_write_path
4949from databento .common .validation import validate_maybe_enum
50- from databento .live import DBNRecord
5150
5251
53- NON_SCHEMA_RECORD_TYPES = [
54- ErrorMsg ,
55- SymbolMappingMsg ,
56- SystemMsg ,
57- Metadata ,
58- ]
59-
60- INT64_NULL = 9223372036854775807
61-
6252logger = logging .getLogger (__name__ )
6353
6454if TYPE_CHECKING :
@@ -380,7 +370,9 @@ def __init__(self, data_source: DataSource) -> None:
380370
381371 def __iter__ (self ) -> Generator [DBNRecord , None , None ]:
382372 reader = self .reader
383- decoder = DBNDecoder ()
373+ decoder = DBNDecoder (
374+ upgrade_policy = VersionUpgradePolicy .UPGRADE ,
375+ )
384376 while True :
385377 raw = reader .read (DBNStore .DBN_READ_SIZE )
386378 if raw :
@@ -936,8 +928,8 @@ def to_df(
936928
937929 df_iter = DataFrameIterator (
938930 records = records ,
939- schema = schema ,
940931 count = count ,
932+ struct_type = self ._schema_struct_map [schema ],
941933 instrument_map = self ._instrument_map ,
942934 price_type = price_type ,
943935 pretty_ts = pretty_ts ,
@@ -1084,13 +1076,13 @@ def to_ndarray(
10841076 ndarray_iter : NDArrayIterator
10851077
10861078 if self .schema is None :
1087- # If schema is None, we're handling heterogeneous data from the live client.
1088- # This is less performant because the records of a given schema are not contiguous in memory.
1079+ # If schema is None, we're handling heterogeneous data from the live client
1080+ # This is less performant because the records of a given schema are not contiguous in memory
10891081 if schema is None :
10901082 raise ValueError ("a schema must be specified for mixed DBN data" )
10911083
1092- schema_struct = SCHEMA_STRUCT_MAP [schema ]
1093- schema_dtype = SCHEMA_DTYPES_MAP [ schema ]
1084+ schema_struct = self . _schema_struct_map [schema ]
1085+ schema_dtype = schema_struct . _dtypes
10941086 schema_filter = filter (lambda r : isinstance (r , schema_struct ), self )
10951087
10961088 ndarray_iter = NDArrayBytesIterator (
@@ -1099,8 +1091,9 @@ def to_ndarray(
10991091 count = count ,
11001092 )
11011093 else :
1102- # If schema is set, we're handling homogeneous historical data.
1103- schema_dtype = SCHEMA_DTYPES_MAP [self .schema ]
1094+ # If schema is set, we're handling homogeneous historical data
1095+ schema_struct = self ._schema_struct_map [self .schema ]
1096+ schema_dtype = schema_struct ._dtypes
11041097
11051098 if self ._metadata .ts_out :
11061099 schema_dtype .append (("ts_out" , "u8" ))
@@ -1145,15 +1138,36 @@ def _transcode(
11451138 pretty_ts = pretty_ts ,
11461139 has_metadata = True ,
11471140 map_symbols = map_symbols ,
1148- symbol_map = symbol_map , # type: ignore [arg-type]
1141+ symbol_interval_map = symbol_map , # type: ignore [arg-type]
11491142 schema = schema ,
11501143 )
11511144
1152- transcoder .write (bytes (self .metadata ))
1153- for records in chunk (self , 2 ** 16 ):
1154- for record in records :
1155- transcoder .write (bytes (record ))
1156- transcoder .flush ()
1145+ reader = self .reader
1146+ transcoder .write (reader .read (self ._metadata_length ))
1147+ while byte_chunk := reader .read (2 ** 16 ):
1148+ transcoder .write (byte_chunk )
1149+
1150+ if transcoder .buffer ():
1151+ raise BentoError (
1152+ "DBN file is truncated or contains an incomplete record" ,
1153+ )
1154+
1155+ transcoder .flush ()
1156+
@property
def _schema_struct_map(self) -> dict[Schema, type[DBNRecord]]:
    """
    Return the mapping of Schema variants to DBNRecord struct types that
    matches this store's DBN metadata version.

    Returns
    -------
    dict[Schema, type[DBNRecord]]

    """
    # Version 1 data uses the legacy struct layouts; everything else
    # resolves to the current-version structs.
    return SCHEMA_STRUCT_MAP_V1 if self.metadata.version == 1 else SCHEMA_STRUCT_MAP
11571171
11581172
11591173class NDArrayIterator (Protocol ):
@@ -1263,31 +1277,30 @@ def __init__(
12631277 self ,
12641278 records : Iterator [np .ndarray [Any , Any ]],
12651279 count : int | None ,
1266- schema : Schema ,
1280+ struct_type : type [ DBNRecord ] ,
12671281 instrument_map : InstrumentMap ,
12681282 price_type : Literal ["fixed" , "float" , "decimal" ] = "float" ,
12691283 pretty_ts : bool = True ,
12701284 map_symbols : bool = True ,
12711285 ):
12721286 self ._records = records
1273- self ._schema = schema
12741287 self ._count = count
1288+ self ._struct_type = struct_type
12751289 self ._price_type = price_type
12761290 self ._pretty_ts = pretty_ts
12771291 self ._map_symbols = map_symbols
12781292 self ._instrument_map = instrument_map
1279- self ._struct = SCHEMA_STRUCT_MAP [schema ]
12801293
12811294 def __iter__ (self ) -> DataFrameIterator :
12821295 return self
12831296
12841297 def __next__ (self ) -> pd .DataFrame :
12851298 df = pd .DataFrame (
12861299 next (self ._records ),
1287- columns = SCHEMA_COLUMNS [ self ._schema ] ,
1300+ columns = self ._struct_type . _ordered_fields ,
12881301 )
12891302
1290- if self ._schema == Schema . DEFINITION :
1303+ if self ._struct_type in ( InstrumentDefMsg , InstrumentDefMsgV1 ) :
12911304 self ._format_definition_fields (df )
12921305
12931306 self ._format_hidden_fields (df )
@@ -1310,8 +1323,8 @@ def _format_definition_fields(self, df: pd.DataFrame) -> None:
13101323 df [column ] = df [column ].where (df [column ] != type_max , np .nan )
13111324
13121325 def _format_hidden_fields (self , df : pd .DataFrame ) -> None :
1313- for column , dtype in SCHEMA_DTYPES_MAP [ self ._schema ] :
1314- hidden_fields = self ._struct ._hidden_fields
1326+ for column , dtype in self ._struct_type . _dtypes :
1327+ hidden_fields = self ._struct_type ._hidden_fields
13151328 if dtype .startswith ("S" ) and column not in hidden_fields :
13161329 df [column ] = df [column ].str .decode ("utf-8" )
13171330
@@ -1328,7 +1341,7 @@ def _format_px(
13281341 df : pd .DataFrame ,
13291342 price_type : Literal ["fixed" , "float" , "decimal" ],
13301343 ) -> None :
1331- px_fields = self ._struct ._price_fields
1344+ px_fields = self ._struct_type ._price_fields
13321345
13331346 if price_type == "decimal" :
13341347 for field in px_fields :
@@ -1343,11 +1356,9 @@ def _format_px(
13431356 return # do nothing
13441357
13451358 def _format_pretty_ts (self , df : pd .DataFrame ) -> None :
1346- for field in self ._struct ._timestamp_fields :
1359+ for field in self ._struct_type ._timestamp_fields :
13471360 df [field ] = pd .to_datetime (df [field ], utc = True , errors = "coerce" )
13481361
13491362 def _format_set_index (self , df : pd .DataFrame ) -> None :
1350- index_column = (
1351- "ts_event" if self ._schema .value .startswith ("ohlcv" ) else "ts_recv"
1352- )
1363+ index_column = self ._struct_type ._ordered_fields [0 ]
13531364 df .set_index (index_column , inplace = True )
0 commit comments