11from __future__ import annotations
22
33import abc
4- import datetime as dt
54import itertools
65import logging
76from collections .abc import Generator
3029from databento .common .data import SCHEMA_DTYPES_MAP
3130from databento .common .data import SCHEMA_STRUCT_MAP
3231from databento .common .error import BentoError
33- from databento .common .symbology import InstrumentIdMappingInterval
32+ from databento .common .symbology import InstrumentMap
3433from databento .common .validation import validate_file_write_path
3534from databento .common .validation import validate_maybe_enum
3635from databento .live import DBNRecord
@@ -98,7 +97,6 @@ def format_dataframe(
9897 schema : Schema ,
9998 pretty_px : bool ,
10099 pretty_ts : bool ,
101- instrument_id_index : dict [dt .date , dict [int , str ]],
102100) -> pd .DataFrame :
103101 struct = SCHEMA_STRUCT_MAP [schema ]
104102
@@ -122,13 +120,6 @@ def format_dataframe(
122120 index_column = "ts_event" if schema .value .startswith ("ohlcv" ) else "ts_recv"
123121 df .set_index (index_column , inplace = True )
124122
125- if instrument_id_index :
126- df_index = df .index if pretty_ts else pd .to_datetime (df .index , utc = True )
127- dates = [ts .date () for ts in df_index ]
128- df ["symbol" ] = [
129- instrument_id_index [dates [i ]][p ] for i , p in enumerate (df ["instrument_id" ])
130- ]
131-
132123 return df
133124
134125
@@ -402,11 +393,7 @@ def __init__(self, data_source: DataSource) -> None:
402393 metadata_bytes .getvalue (),
403394 )
404395
405- # This is populated when _map_symbols is called
406- self ._instrument_id_index : dict [
407- dt .date ,
408- dict [int , str ],
409- ] = {}
396+ self ._instrument_map = InstrumentMap ()
410397
411398 def __iter__ (self ) -> Generator [DBNRecord , None , None ]:
412399 reader = self .reader
@@ -422,6 +409,8 @@ def __iter__(self) -> Generator[DBNRecord, None, None]:
422409 for record in records :
423410 if isinstance (record , databento_dbn .Metadata ):
424411 continue
412+ if isinstance (record , databento_dbn .SymbolMappingMsg ):
413+ self ._instrument_map .insert_symbol_mapping_msg (record )
425414 yield record
426415 else :
427416 if len (decoder .buffer ()) > 0 :
@@ -434,38 +423,6 @@ def __repr__(self) -> str:
434423 name = self .__class__ .__name__
435424 return f"<{ name } (schema={ self .schema } )>"
436425
437- def _build_instrument_id_index (self ) -> dict [dt .date , dict [int , str ]]:
438- intervals : list [InstrumentIdMappingInterval ] = []
439- for raw_symbol , i in self .mappings .items ():
440- for row in i :
441- symbol = row ["symbol" ]
442- if symbol == "" :
443- continue
444- intervals .append (
445- InstrumentIdMappingInterval (
446- start_date = row ["start_date" ],
447- end_date = row ["end_date" ],
448- raw_symbol = raw_symbol ,
449- instrument_id = int (row ["symbol" ]),
450- ),
451- )
452-
453- instrument_id_index : dict [dt .date , dict [int , str ]] = {}
454- for interval in intervals :
455- for ts in pd .date_range (
456- start = interval .start_date ,
457- end = interval .end_date ,
458- # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.date_range.html
459- ** {"inclusive" if pd .__version__ >= "1.4.0" else "closed" : "left" },
460- ):
461- d : dt .date = ts .date ()
462- date_map : dict [int , str ] = instrument_id_index .get (d , {})
463- if not date_map :
464- instrument_id_index [d ] = date_map
465- date_map [interval .instrument_id ] = interval .raw_symbol
466-
467- return instrument_id_index
468-
469426 @property
470427 def compression (self ) -> Compression :
471428 """
@@ -813,13 +770,20 @@ def request_symbology(self, client: Historical) -> dict[str, Any]:
813770 date range.
814771
815772 """
773+ if self .end is None :
774+ end_date = None
775+ elif self .start .date () == self .end .date ():
776+ end_date = (self .start + pd .Timedelta (days = 1 )).date ()
777+ else :
778+ end_date = self .end
779+
816780 return client .symbology .resolve (
817781 dataset = self .dataset ,
818782 symbols = self .symbols ,
819783 stype_in = self .stype_in ,
820784 stype_out = self .stype_out ,
821785 start_date = self .start .date (),
822- end_date = self . end . date () if self . end else None ,
786+ end_date = end_date ,
823787 )
824788
825789 def to_csv (
@@ -882,7 +846,7 @@ def to_df(
882846 self ,
883847 pretty_px : bool = ...,
884848 pretty_ts : bool = ...,
885- map_symbols : bool | None = ...,
849+ map_symbols : bool = ...,
886850 schema : Schema | str | None = ...,
887851 count : None = ...,
888852 ) -> pd .DataFrame :
@@ -893,7 +857,7 @@ def to_df(
893857 self ,
894858 pretty_px : bool = ...,
895859 pretty_ts : bool = ...,
896- map_symbols : bool | None = ...,
860+ map_symbols : bool = ...,
897861 schema : Schema | str | None = ...,
898862 count : int = ...,
899863 ) -> DataFrameIterator :
@@ -903,7 +867,7 @@ def to_df(
903867 self ,
904868 pretty_px : bool = True ,
905869 pretty_ts : bool = True ,
906- map_symbols : bool | None = None ,
870+ map_symbols : bool = True ,
907871 schema : Schema | str | None = None ,
908872 count : int | None = None ,
909873 ) -> pd .DataFrame | DataFrameIterator :
@@ -950,29 +914,22 @@ def to_df(
950914 raise ValueError ("a schema must be specified for mixed DBN data" )
951915 schema = self .schema
952916
953- if map_symbols is None :
954- map_symbols = self .stype_out == SType .INSTRUMENT_ID
955-
956- if map_symbols :
957- if self .stype_out != SType .INSTRUMENT_ID :
958- raise ValueError (
959- "`map_symbols` is not supported when `stype_out` is not 'instrument_id'" ,
960- )
961- if not self ._instrument_id_index :
962- self ._instrument_id_index = self ._build_instrument_id_index ()
963-
964917 if count is None :
965918 records = iter ([self .to_ndarray (schema )])
966919 else :
967920 records = self .to_ndarray (schema , count )
968921
922+ if map_symbols :
923+ self ._instrument_map .insert_metadata (self .metadata )
924+
969925 df_iter = DataFrameIterator (
970926 records = records ,
971927 schema = schema ,
972928 count = count ,
929+ instrument_map = self ._instrument_map ,
973930 pretty_px = pretty_px ,
974931 pretty_ts = pretty_ts ,
975- instrument_id_index = self . _instrument_id_index if map_symbols else {} ,
932+ map_symbols = map_symbols ,
976933 )
977934
978935 if count is None :
@@ -1168,30 +1125,38 @@ def __init__(
11681125 records : Iterator [np .ndarray [Any , Any ]],
11691126 count : int | None ,
11701127 schema : Schema ,
1128+ instrument_map : InstrumentMap ,
11711129 pretty_px : bool = True ,
11721130 pretty_ts : bool = True ,
1173- instrument_id_index : dict [ dt . date , dict [ int , str ]] | None = None ,
1131+ map_symbols : bool = True ,
11741132 ):
11751133 self ._records = records
11761134 self ._schema = schema
11771135 self ._count = count
11781136 self ._pretty_px = pretty_px
11791137 self ._pretty_ts = pretty_ts
1180- self ._instrument_id_index = (
1181- instrument_id_index if instrument_id_index is not None else {}
1182- )
1138+ self ._map_symbols = map_symbols
1139+ self ._instrument_map = instrument_map
11831140
11841141 def __iter__ (self ) -> DataFrameIterator :
11851142 return self
11861143
11871144 def __next__ (self ) -> pd .DataFrame :
1188- return format_dataframe (
1145+ df = format_dataframe (
11891146 pd .DataFrame (
11901147 next (self ._records ),
11911148 columns = SCHEMA_COLUMNS [self ._schema ],
11921149 ),
11931150 schema = self ._schema ,
11941151 pretty_px = self ._pretty_px ,
11951152 pretty_ts = self ._pretty_ts ,
1196- instrument_id_index = self ._instrument_id_index ,
11971153 )
1154+
1155+ if self ._map_symbols :
1156+ df_index = df .index if self ._pretty_ts else pd .to_datetime (df .index , utc = True )
1157+ dates = [ts .date () for ts in df_index ]
1158+ df ["symbol" ] = [
1159+ self ._instrument_map .resolve (inst , dates [i ]) for i , inst in enumerate (df ["instrument_id" ])
1160+ ]
1161+
1162+ return df
0 commit comments