Skip to content

Commit 1fb3ddf

Browse files
committed
MOD: Ensure empty dataframe columns
1 parent 683bba1 commit 1fb3ddf

File tree

3 files changed

+138
-112
lines changed

3 files changed

+138
-112
lines changed

databento/common/bento.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,13 +275,13 @@ def __init__(self, data_source: DataSource) -> None:
275275
byteorder="little",
276276
)
277277

278-
buffer.seek(0) # rewind to read the entire header
278+
buffer.seek(0) # Rewind to read the entire header
279279

280280
self._metadata: Dict[str, Any] = MetadataDecoder.decode_to_json(
281281
raw_metadata=buffer.read(8 + metadata_length),
282282
)
283283

284-
# This is populated when _map_symbols is called.
284+
# This is populated when _map_symbols is called
285285
self._product_id_index: Dict[
286286
dt.date,
287287
Dict[int, str],
@@ -354,10 +354,10 @@ def _build_product_id_index(self) -> Dict[dt.date, Dict[int, str]]:
354354

355355
return product_id_index
356356

357-
def _cleanup_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
357+
def _prepare_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
358+
df.set_index(self._get_index_column(), inplace=True)
358359
df.drop(["length", "rtype"], axis=1, inplace=True)
359360
if self.schema == Schema.MBO or self.schema in DERIV_SCHEMAS:
360-
df = df.reindex(columns=COLUMNS[self.schema])
361361
df["flags"] = df["flags"] & 0xFF # Apply bitmask
362362
df["side"] = df["side"].str.decode("utf-8")
363363
df["action"] = df["action"].str.decode("utf-8")
@@ -368,6 +368,9 @@ def _cleanup_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
368368
if column in df.columns:
369369
df[column] = df[column].where(df[column] != type_max, np.nan)
370370

371+
# Reorder columns
372+
df = df.reindex(columns=COLUMNS[self.schema])
373+
371374
return df
372375

373376
def _get_index_column(self) -> str:
@@ -834,8 +837,7 @@ def to_df(
834837
835838
"""
836839
df = pd.DataFrame(self.to_ndarray())
837-
df.set_index(self._get_index_column(), inplace=True)
838-
df = self._cleanup_dataframe(df)
840+
df = self._prepare_dataframe(df)
839841

840842
if pretty_ts:
841843
df = self._apply_pretty_ts(df)
@@ -908,4 +910,4 @@ def to_ndarray(self) -> np.ndarray[Any, Any]:
908910
909911
"""
910912
data: bytes = self.reader.read()
911-
return np.frombuffer(data, dtype=STRUCT_MAP[self.schema])
913+
return np.frombuffer(data, dtype=self.dtype)

databento/common/data.py

Lines changed: 125 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,20 @@ def get_deriv_ba_types(level: int) -> List[Tuple[str, Union[type, str]]]:
4444
("ts_event", np.uint64),
4545
]
4646

47+
MBO_MSG: List[Tuple[str, Union[type, str]]] = RECORD_HEADER + [
48+
("order_id", np.uint64),
49+
("price", np.int64),
50+
("size", np.uint32),
51+
("flags", np.int8),
52+
("channel_id", np.uint8),
53+
("action", "S1"), # 1 byte chararray
54+
("side", "S1"), # 1 byte chararray
55+
("ts_recv", np.uint64),
56+
("ts_in_delta", np.int32),
57+
("sequence", np.uint32),
58+
]
4759

48-
MBP_MSG: List[Tuple[str, Union[type, str]]] = [
60+
MBP_MSG: List[Tuple[str, Union[type, str]]] = RECORD_HEADER + [
4961
("price", np.int64),
5062
("size", np.uint32),
5163
("action", "S1"), # 1 byte chararray
@@ -58,32 +70,91 @@ def get_deriv_ba_types(level: int) -> List[Tuple[str, Union[type, str]]]:
5870
]
5971

6072

61-
OHLCV_MSG: List[Tuple[str, Union[type, str]]] = [
73+
OHLCV_MSG: List[Tuple[str, Union[type, str]]] = RECORD_HEADER + [
6274
("open", np.int64),
6375
("high", np.int64),
6476
("low", np.int64),
6577
("close", np.int64),
6678
("volume", np.int64),
6779
]
6880

81+
STATUS_MSG: List[Tuple[str, Union[type, str]]] = RECORD_HEADER + [
82+
("ts_recv", np.uint64),
83+
("group", "S1"), # 1 byte chararray
84+
("trading_status", np.uint8),
85+
("halt_reason", np.uint8),
86+
("trading_event", np.uint8),
87+
]
88+
89+
DEFINITION_MSG: List[Tuple[str, Union[type, str]]] = RECORD_HEADER + [
90+
("ts_recv", np.uint64),
91+
("min_price_increment", np.int64),
92+
("display_factor", np.int64),
93+
("expiration", np.uint64),
94+
("activation", np.uint64),
95+
("high_limit_price", np.int64),
96+
("low_limit_price", np.int64),
97+
("max_price_variation", np.int64),
98+
("trading_reference_price", np.int64),
99+
("unit_of_measure_qty", np.int64),
100+
("min_price_increment_amount", np.int64),
101+
("price_ratio", np.int64),
102+
("inst_attrib_value", np.int32),
103+
("underlying_id", np.uint32),
104+
("cleared_volume", np.int32),
105+
("market_depth_implied", np.int32),
106+
("market_depth", np.int32),
107+
("market_segment_id", np.uint32),
108+
("max_trade_vol", np.uint32),
109+
("min_lot_size", np.int32),
110+
("min_lot_size_block", np.int32),
111+
("min_lot_size_round_lot", np.int32),
112+
("min_trade_vol", np.uint32),
113+
("open_interest_qty", np.int32),
114+
("contract_multiplier", np.int32),
115+
("decay_quantity", np.int32),
116+
("original_contract_size", np.int32),
117+
("related_security_id", np.uint32),
118+
("trading_reference_date", np.uint16),
119+
("appl_id", np.int16),
120+
("maturity_year", np.uint16),
121+
("decay_start_date", np.uint16),
122+
("channel_id", np.uint16),
123+
("currency", "S4"), # 4 byte chararray
124+
("settl_currency", "S4"), # 4 byte chararray
125+
("secsubtype", "S6"), # 6 byte chararray
126+
("symbol", "S22"), # 22 byte chararray
127+
("group", "S21"), # 21 byte chararray
128+
("exchange", "S5"), # 5 byte chararray
129+
("asset", "S7"), # 7 byte chararray
130+
("cfi", "S7"), # 7 byte chararray
131+
("security_type", "S7"), # 7 byte chararray
132+
("unit_of_measure", "S31"), # 31 byte chararray
133+
("underlying", "S21"), # 21 byte chararray
134+
("related", "S21"), # 21 byte chararray
135+
("match_algorithm", "S1"), # 1 byte chararray
136+
("md_security_trading_status", np.uint8),
137+
("main_fraction", np.uint8),
138+
("price_display_format", np.uint8),
139+
("settl_price_type", np.uint8),
140+
("sub_fraction", np.uint8),
141+
("underlying_product", np.uint8),
142+
("security_update_action", "S1"), # 1 byte chararray
143+
("maturity_month", np.uint8),
144+
("maturity_day", np.uint8),
145+
("maturity_week", np.uint8),
146+
("user_defined_instrument", "S1"), # 1 byte chararray
147+
("contract_multiplier_unit", np.int8),
148+
("flow_schedule_type", np.int8),
149+
("tick_rule", np.uint8),
150+
("dummy", "S3"), # 3 byte chararray (Adjustment filler for 8-bytes alignment)
151+
]
152+
69153

70154
STRUCT_MAP: Dict[Schema, List[Tuple[str, Union[type, str]]]] = {
71-
Schema.MBO: RECORD_HEADER
72-
+ [
73-
("order_id", np.uint64),
74-
("price", np.int64),
75-
("size", np.uint32),
76-
("flags", np.int8),
77-
("channel_id", np.uint8),
78-
("action", "S1"), # 1 byte chararray
79-
("side", "S1"), # 1 byte chararray
80-
("ts_recv", np.uint64),
81-
("ts_in_delta", np.int32),
82-
("sequence", np.uint32),
83-
],
84-
Schema.MBP_1: RECORD_HEADER + MBP_MSG + get_deriv_ba_types(0), # 1
85-
Schema.MBP_10: RECORD_HEADER
86-
+ MBP_MSG
155+
Schema.MBO: MBO_MSG,
156+
Schema.MBP_1: MBP_MSG + get_deriv_ba_types(0), # 1
157+
Schema.MBP_10: MBP_MSG
87158
+ get_deriv_ba_types(0) # 1
88159
+ get_deriv_ba_types(1) # 2
89160
+ get_deriv_ba_types(2) # 3
@@ -94,84 +165,14 @@ def get_deriv_ba_types(level: int) -> List[Tuple[str, Union[type, str]]]:
94165
+ get_deriv_ba_types(7) # 8
95166
+ get_deriv_ba_types(8) # 9
96167
+ get_deriv_ba_types(9), # 10
97-
Schema.TBBO: RECORD_HEADER + MBP_MSG + get_deriv_ba_types(0),
98-
Schema.TRADES: RECORD_HEADER + MBP_MSG,
99-
Schema.OHLCV_1S: RECORD_HEADER + OHLCV_MSG,
100-
Schema.OHLCV_1M: RECORD_HEADER + OHLCV_MSG,
101-
Schema.OHLCV_1H: RECORD_HEADER + OHLCV_MSG,
102-
Schema.OHLCV_1D: RECORD_HEADER + OHLCV_MSG,
103-
Schema.STATUS: RECORD_HEADER
104-
+ [
105-
("ts_recv", np.uint64),
106-
("group", "S1"), # 1 byte chararray
107-
("trading_status", np.uint8),
108-
("halt_reason", np.uint8),
109-
("trading_event", np.uint8),
110-
],
111-
Schema.DEFINITION: RECORD_HEADER
112-
+ [
113-
("ts_recv", np.uint64),
114-
("min_price_increment", np.int64),
115-
("display_factor", np.int64),
116-
("expiration", np.uint64),
117-
("activation", np.uint64),
118-
("high_limit_price", np.int64),
119-
("low_limit_price", np.int64),
120-
("max_price_variation", np.int64),
121-
("trading_reference_price", np.int64),
122-
("unit_of_measure_qty", np.int64),
123-
("min_price_increment_amount", np.int64),
124-
("price_ratio", np.int64),
125-
("inst_attrib_value", np.int32),
126-
("underlying_id", np.uint32),
127-
("cleared_volume", np.int32),
128-
("market_depth_implied", np.int32),
129-
("market_depth", np.int32),
130-
("market_segment_id", np.uint32),
131-
("max_trade_vol", np.uint32),
132-
("min_lot_size", np.int32),
133-
("min_lot_size_block", np.int32),
134-
("min_lot_size_round_lot", np.int32),
135-
("min_trade_vol", np.uint32),
136-
("open_interest_qty", np.int32),
137-
("contract_multiplier", np.int32),
138-
("decay_quantity", np.int32),
139-
("original_contract_size", np.int32),
140-
("related_security_id", np.uint32),
141-
("trading_reference_date", np.uint16),
142-
("appl_id", np.int16),
143-
("maturity_year", np.uint16),
144-
("decay_start_date", np.uint16),
145-
("channel_id", np.uint16),
146-
("currency", "S4"), # 4 byte chararray
147-
("settl_currency", "S4"), # 4 byte chararray
148-
("secsubtype", "S6"), # 6 byte chararray
149-
("symbol", "S22"), # 22 byte chararray
150-
("group", "S21"), # 21 byte chararray
151-
("exchange", "S5"), # 5 byte chararray
152-
("asset", "S7"), # 7 byte chararray
153-
("cfi", "S7"), # 7 byte chararray
154-
("security_type", "S7"), # 7 byte chararray
155-
("unit_of_measure", "S31"), # 31 byte chararray
156-
("underlying", "S21"), # 21 byte chararray
157-
("related", "S21"), # 21 byte chararray
158-
("match_algorithm", "S1"), # 1 byte chararray
159-
("md_security_trading_status", np.uint8),
160-
("main_fraction", np.uint8),
161-
("price_display_format", np.uint8),
162-
("settl_price_type", np.uint8),
163-
("sub_fraction", np.uint8),
164-
("underlying_product", np.uint8),
165-
("security_update_action", "S1"), # 1 byte chararray
166-
("maturity_month", np.uint8),
167-
("maturity_day", np.uint8),
168-
("maturity_week", np.uint8),
169-
("user_defined_instrument", "S1"), # 1 byte chararray
170-
("contract_multiplier_unit", np.int8),
171-
("flow_schedule_type", np.int8),
172-
("tick_rule", np.uint8),
173-
("dummy", "S3"), # 3 byte chararray (Adjustment filler for 8-bytes alignment)
174-
],
168+
Schema.TBBO: MBP_MSG + get_deriv_ba_types(0),
169+
Schema.TRADES: MBP_MSG,
170+
Schema.OHLCV_1S: OHLCV_MSG,
171+
Schema.OHLCV_1M: OHLCV_MSG,
172+
Schema.OHLCV_1H: OHLCV_MSG,
173+
Schema.OHLCV_1D: OHLCV_MSG,
174+
Schema.STATUS: STATUS_MSG,
175+
Schema.DEFINITION: DEFINITION_MSG,
175176
Schema.GATEWAY_ERROR: RECORD_HEADER
176177
+ [
177178
("error", "S64"),
@@ -236,7 +237,7 @@ def get_deriv_ba_fields(level: int) -> List[str]:
236237
]
237238

238239

239-
DERIV_HEADER_FIELDS = [
240+
DERIV_HEADER_COLUMNS = [
240241
"ts_event",
241242
"ts_in_delta",
242243
"publisher_id",
@@ -250,6 +251,23 @@ def get_deriv_ba_fields(level: int) -> List[str]:
250251
"sequence",
251252
]
252253

254+
OHLCV_HEADER_COLUMNS = [
255+
"publisher_id",
256+
"product_id",
257+
"open",
258+
"high",
259+
"low",
260+
"close",
261+
"volume",
262+
]
263+
264+
STATUS_COLUMNS = [x for x in np.dtype(STATUS_MSG).names or ()]
265+
STATUS_COLUMNS.remove("ts_recv") # Index
266+
267+
DEFINITION_COLUMNS = [x for x in np.dtype(DEFINITION_MSG).names or ()]
268+
DEFINITION_COLUMNS.remove("ts_recv") # Index
269+
270+
253271
COLUMNS = {
254272
Schema.MBO: [
255273
"ts_event",
@@ -265,8 +283,8 @@ def get_deriv_ba_fields(level: int) -> List[str]:
265283
"size",
266284
"sequence",
267285
],
268-
Schema.MBP_1: DERIV_HEADER_FIELDS + get_deriv_ba_fields(0),
269-
Schema.MBP_10: DERIV_HEADER_FIELDS
286+
Schema.MBP_1: DERIV_HEADER_COLUMNS + get_deriv_ba_fields(0),
287+
Schema.MBP_10: DERIV_HEADER_COLUMNS
270288
+ get_deriv_ba_fields(0)
271289
+ get_deriv_ba_fields(1)
272290
+ get_deriv_ba_fields(2)
@@ -277,6 +295,12 @@ def get_deriv_ba_fields(level: int) -> List[str]:
277295
+ get_deriv_ba_fields(7)
278296
+ get_deriv_ba_fields(8)
279297
+ get_deriv_ba_fields(9),
280-
Schema.TBBO: DERIV_HEADER_FIELDS + get_deriv_ba_fields(0),
281-
Schema.TRADES: DERIV_HEADER_FIELDS,
298+
Schema.TBBO: DERIV_HEADER_COLUMNS + get_deriv_ba_fields(0),
299+
Schema.TRADES: DERIV_HEADER_COLUMNS,
300+
Schema.OHLCV_1S: OHLCV_HEADER_COLUMNS,
301+
Schema.OHLCV_1M: OHLCV_HEADER_COLUMNS,
302+
Schema.OHLCV_1H: OHLCV_HEADER_COLUMNS,
303+
Schema.OHLCV_1D: OHLCV_HEADER_COLUMNS,
304+
Schema.STATUS: STATUS_COLUMNS,
305+
Schema.DEFINITION: DEFINITION_COLUMNS,
282306
}

databento/historical/api/timeseries.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def get_range(
127127
("schema", str(schema_valid)),
128128
("stype_in", str(stype_in_valid)),
129129
("stype_out", str(validate_enum(stype_out, SType, "stype_out"))),
130-
("encoding", str(Encoding.DBN)), # always request dbn
131-
("compression", str(Compression.ZSTD)), # always request zstd
130+
("encoding", str(Encoding.DBN)), # Always request dbn
131+
("compression", str(Compression.ZSTD)), # Always request zstd
132132
]
133133

134134
# Optional Parameters
@@ -263,8 +263,8 @@ async def get_range_async(
263263
("schema", str(schema_valid)),
264264
("stype_in", str(stype_in_valid)),
265265
("stype_out", str(validate_enum(stype_out, SType, "stype_out"))),
266-
("encoding", str(Encoding.DBN)), # always request dbn
267-
("compression", str(Compression.ZSTD)), # always request zstd
266+
("encoding", str(Encoding.DBN)), # Always request dbn
267+
("compression", str(Compression.ZSTD)), # Always request zstd
268268
]
269269

270270
if limit is not None:

0 commit comments

Comments
 (0)