Skip to content

Commit 935e67f

Browse files
committed
MOD: Client symbology mapping improvement
1 parent d8363ca commit 935e67f

File tree

7 files changed

+861
-131
lines changed

7 files changed

+861
-131
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22

33
## 0.21.0 - TBD
44

5+
#### Enhancements
6+
- Added `map_symbols` support for DBN data generated by the `Live` client
7+
58
#### Bug fixes
69
- Fixed an issue where `DBNStore.from_bytes` did not rewind seekable buffers
10+
- Fixed an issue where the `DBNStore` would not map symbols with an input symbology of `SType.INSTRUMENT_ID`
11+
- Fixed an issue with `DBNStore.request_symbology` when the DBN metadata's start date and end date were the same
712

813
## 0.20.0 - 2023-09-21
914

databento/common/dbnstore.py

Lines changed: 34 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import abc
4-
import datetime as dt
54
import itertools
65
import logging
76
from collections.abc import Generator
@@ -30,7 +29,7 @@
3029
from databento.common.data import SCHEMA_DTYPES_MAP
3130
from databento.common.data import SCHEMA_STRUCT_MAP
3231
from databento.common.error import BentoError
33-
from databento.common.symbology import InstrumentIdMappingInterval
32+
from databento.common.symbology import InstrumentMap
3433
from databento.common.validation import validate_file_write_path
3534
from databento.common.validation import validate_maybe_enum
3635
from databento.live import DBNRecord
@@ -98,7 +97,6 @@ def format_dataframe(
9897
schema: Schema,
9998
pretty_px: bool,
10099
pretty_ts: bool,
101-
instrument_id_index: dict[dt.date, dict[int, str]],
102100
) -> pd.DataFrame:
103101
struct = SCHEMA_STRUCT_MAP[schema]
104102

@@ -122,13 +120,6 @@ def format_dataframe(
122120
index_column = "ts_event" if schema.value.startswith("ohlcv") else "ts_recv"
123121
df.set_index(index_column, inplace=True)
124122

125-
if instrument_id_index:
126-
df_index = df.index if pretty_ts else pd.to_datetime(df.index, utc=True)
127-
dates = [ts.date() for ts in df_index]
128-
df["symbol"] = [
129-
instrument_id_index[dates[i]][p] for i, p in enumerate(df["instrument_id"])
130-
]
131-
132123
return df
133124

134125

@@ -402,11 +393,7 @@ def __init__(self, data_source: DataSource) -> None:
402393
metadata_bytes.getvalue(),
403394
)
404395

405-
# This is populated when _map_symbols is called
406-
self._instrument_id_index: dict[
407-
dt.date,
408-
dict[int, str],
409-
] = {}
396+
self._instrument_map = InstrumentMap()
410397

411398
def __iter__(self) -> Generator[DBNRecord, None, None]:
412399
reader = self.reader
@@ -422,6 +409,8 @@ def __iter__(self) -> Generator[DBNRecord, None, None]:
422409
for record in records:
423410
if isinstance(record, databento_dbn.Metadata):
424411
continue
412+
if isinstance(record, databento_dbn.SymbolMappingMsg):
413+
self._instrument_map.insert_symbol_mapping_msg(record)
425414
yield record
426415
else:
427416
if len(decoder.buffer()) > 0:
@@ -434,38 +423,6 @@ def __repr__(self) -> str:
434423
name = self.__class__.__name__
435424
return f"<{name}(schema={self.schema})>"
436425

437-
def _build_instrument_id_index(self) -> dict[dt.date, dict[int, str]]:
438-
intervals: list[InstrumentIdMappingInterval] = []
439-
for raw_symbol, i in self.mappings.items():
440-
for row in i:
441-
symbol = row["symbol"]
442-
if symbol == "":
443-
continue
444-
intervals.append(
445-
InstrumentIdMappingInterval(
446-
start_date=row["start_date"],
447-
end_date=row["end_date"],
448-
raw_symbol=raw_symbol,
449-
instrument_id=int(row["symbol"]),
450-
),
451-
)
452-
453-
instrument_id_index: dict[dt.date, dict[int, str]] = {}
454-
for interval in intervals:
455-
for ts in pd.date_range(
456-
start=interval.start_date,
457-
end=interval.end_date,
458-
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.date_range.html
459-
**{"inclusive" if pd.__version__ >= "1.4.0" else "closed": "left"},
460-
):
461-
d: dt.date = ts.date()
462-
date_map: dict[int, str] = instrument_id_index.get(d, {})
463-
if not date_map:
464-
instrument_id_index[d] = date_map
465-
date_map[interval.instrument_id] = interval.raw_symbol
466-
467-
return instrument_id_index
468-
469426
@property
470427
def compression(self) -> Compression:
471428
"""
@@ -813,13 +770,20 @@ def request_symbology(self, client: Historical) -> dict[str, Any]:
813770
date range.
814771
815772
"""
773+
if self.end is None:
774+
end_date = None
775+
elif self.start.date() == self.end.date():
776+
end_date = (self.start + pd.Timedelta(days=1)).date()
777+
else:
778+
end_date = self.end
779+
816780
return client.symbology.resolve(
817781
dataset=self.dataset,
818782
symbols=self.symbols,
819783
stype_in=self.stype_in,
820784
stype_out=self.stype_out,
821785
start_date=self.start.date(),
822-
end_date=self.end.date() if self.end else None,
786+
end_date=end_date,
823787
)
824788

825789
def to_csv(
@@ -882,7 +846,7 @@ def to_df(
882846
self,
883847
pretty_px: bool = ...,
884848
pretty_ts: bool = ...,
885-
map_symbols: bool | None = ...,
849+
map_symbols: bool = ...,
886850
schema: Schema | str | None = ...,
887851
count: None = ...,
888852
) -> pd.DataFrame:
@@ -893,7 +857,7 @@ def to_df(
893857
self,
894858
pretty_px: bool = ...,
895859
pretty_ts: bool = ...,
896-
map_symbols: bool | None = ...,
860+
map_symbols: bool = ...,
897861
schema: Schema | str | None = ...,
898862
count: int = ...,
899863
) -> DataFrameIterator:
@@ -903,7 +867,7 @@ def to_df(
903867
self,
904868
pretty_px: bool = True,
905869
pretty_ts: bool = True,
906-
map_symbols: bool | None = None,
870+
map_symbols: bool = True,
907871
schema: Schema | str | None = None,
908872
count: int | None = None,
909873
) -> pd.DataFrame | DataFrameIterator:
@@ -950,29 +914,22 @@ def to_df(
950914
raise ValueError("a schema must be specified for mixed DBN data")
951915
schema = self.schema
952916

953-
if map_symbols is None:
954-
map_symbols = self.stype_out == SType.INSTRUMENT_ID
955-
956-
if map_symbols:
957-
if self.stype_out != SType.INSTRUMENT_ID:
958-
raise ValueError(
959-
"`map_symbols` is not supported when `stype_out` is not 'instrument_id'",
960-
)
961-
if not self._instrument_id_index:
962-
self._instrument_id_index = self._build_instrument_id_index()
963-
964917
if count is None:
965918
records = iter([self.to_ndarray(schema)])
966919
else:
967920
records = self.to_ndarray(schema, count)
968921

922+
if map_symbols:
923+
self._instrument_map.insert_metadata(self.metadata)
924+
969925
df_iter = DataFrameIterator(
970926
records=records,
971927
schema=schema,
972928
count=count,
929+
instrument_map=self._instrument_map,
973930
pretty_px=pretty_px,
974931
pretty_ts=pretty_ts,
975-
instrument_id_index=self._instrument_id_index if map_symbols else {},
932+
map_symbols=map_symbols,
976933
)
977934

978935
if count is None:
@@ -1168,30 +1125,38 @@ def __init__(
11681125
records: Iterator[np.ndarray[Any, Any]],
11691126
count: int | None,
11701127
schema: Schema,
1128+
instrument_map: InstrumentMap,
11711129
pretty_px: bool = True,
11721130
pretty_ts: bool = True,
1173-
instrument_id_index: dict[dt.date, dict[int, str]] | None = None,
1131+
map_symbols: bool = True,
11741132
):
11751133
self._records = records
11761134
self._schema = schema
11771135
self._count = count
11781136
self._pretty_px = pretty_px
11791137
self._pretty_ts = pretty_ts
1180-
self._instrument_id_index = (
1181-
instrument_id_index if instrument_id_index is not None else {}
1182-
)
1138+
self._map_symbols = map_symbols
1139+
self._instrument_map = instrument_map
11831140

11841141
def __iter__(self) -> DataFrameIterator:
11851142
return self
11861143

11871144
def __next__(self) -> pd.DataFrame:
1188-
return format_dataframe(
1145+
df = format_dataframe(
11891146
pd.DataFrame(
11901147
next(self._records),
11911148
columns=SCHEMA_COLUMNS[self._schema],
11921149
),
11931150
schema=self._schema,
11941151
pretty_px=self._pretty_px,
11951152
pretty_ts=self._pretty_ts,
1196-
instrument_id_index=self._instrument_id_index,
11971153
)
1154+
1155+
if self._map_symbols:
1156+
df_index = df.index if self._pretty_ts else pd.to_datetime(df.index, utc=True)
1157+
dates = [ts.date() for ts in df_index]
1158+
df["symbol"] = [
1159+
self._instrument_map.resolve(inst, dates[i]) for i, inst in enumerate(df["instrument_id"])
1160+
]
1161+
1162+
return df

0 commit comments

Comments
 (0)