
Commit de86e46

refactor: introduce BaseDataHandler and unify fetch interface (#1958)
* refactor: introduce BaseDataHandler and unify fetch interface
* refactor: include data_key in seg_kwargs and simplify segments loop
* refactor: default data_key to BaseDataHandler.DK_I in _get_df_by_key
* style: fix indentation and remove extra blank lines in data handlers
* refactor: use BaseDataHandler.DK_I as default data_key
* docs: fix BaseDataHandler docstring grammar and formatting
* refactor: remove unused **kwargs from storage fetch methods
* docs: refine BaseDataHandler and DataHandler docstrings
* refactor: rename BaseDataHandler to DataHandlerABC, update type hints
* feat: add flt_col to TSDatasetH and list-to-slice conversion in storage
* lint
* comment
1 parent ba8b6cc commit de86e46

File tree: 4 files changed, +126 −61 lines


qlib/data/dataset/__init__.py

Lines changed: 13 additions & 13 deletions
@@ -226,23 +226,20 @@ def prepare(
         ------
         NotImplementedError:
         """
-        logger = get_module_logger("DatasetH")
-        seg_kwargs = {"col_set": col_set}
+        seg_kwargs = {"col_set": col_set, "data_key": data_key}
         seg_kwargs.update(kwargs)
-        if "data_key" in getfullargspec(self.handler.fetch).args:
-            seg_kwargs["data_key"] = data_key
-        else:
-            logger.info(f"data_key[{data_key}] is ignored.")

         # Conflicts may happen here
         # - The fetched data and the segment key may both be strings
         # To resolve the conflict
         # - The segment name has higher priority

         # 1) Use it as a segment name first
+        # 1.1) directly fetch a single split like "train", "valid", "test"
         if isinstance(segments, str) and segments in self.segments:
             return self._prepare_seg(self.segments[segments], **seg_kwargs)

+        # 1.2) fetch multiple splits like ["train", "valid"] or ["train", "valid", "test"]
         if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
             return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
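
Since data_key is now always placed in seg_kwargs, prepare() no longer needs to introspect the handler's fetch signature. A minimal usage sketch (the initialized DatasetH instance `dataset` and its DataHandlerLP-style handler are assumptions, not part of the commit):

    from qlib.data.dataset.handler import DataHandlerLP

    # case 1.1: a single split name is resolved against self.segments first
    df_train = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)

    # case 1.2: multiple split names return one prepared object per segment
    df_train, df_valid = dataset.prepare(["train", "valid"], data_key=DataHandlerLP.DK_I)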

@@ -262,7 +259,7 @@ def get_max_time(segments):
     def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp):
         """it will act like sort and return the max value or None"""
         candidate = None
-        for k, seg in segments.items():
+        for _, seg in segments.items():
             point = seg[idx]
             if point is None:
                 # None indicates unbounded, return directly
@@ -376,6 +373,8 @@ def __init__(
             ffill with previous samples first and fill with later samples second
         flt_data : pd.Series
             a column of data (True or False) to filter data. Its index order is <"datetime", "instrument">
+            This feature is essential because:
+            - some samples should be excluded by label-based filtering, but they can't be dropped up front because their features are still needed to build the time-series samples around them.
         None:
             keep all data
@@ -661,8 +660,9 @@ class TSDatasetH(DatasetH):

     DEFAULT_STEP_LEN = 30

-    def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
+    def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs):
         self.step_len = step_len
+        self.flt_col = flt_col
         super().__init__(**kwargs)

     def config(self, **kwargs):
@@ -693,10 +693,10 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
         dtype = kwargs.pop("dtype", None)
         if not isinstance(slc, slice):
             slc = slice(*slc)
-        start, end = slc.start, slc.stop
-        flt_col = kwargs.pop("flt_col", None)
-        # TSDatasetH will retrieve more data for complete time-series
+        if (flt_col := kwargs.pop("flt_col", None)) is None:
+            flt_col = self.flt_col

+        # TSDatasetH will retrieve more data for complete time-series
         ext_slice = self._extend_slice(slc, self.cal, self.step_len)
         data = super()._prepare_seg(ext_slice, **kwargs)

@@ -710,8 +710,8 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:

         tsds = TSDataSampler(
             data=data,
-            start=start,
-            end=end,
+            start=slc.start,
+            end=slc.stop,
             step_len=self.step_len,
             dtype=dtype,
             flt_data=flt_data,
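
A sketch of the new flt_col argument. The handler and the "filter" column name below are assumptions for illustration; flt_col should name a boolean column emitted by the handler, and an explicit flt_col passed to prepare() overrides the constructor default:

    ts_dataset = TSDatasetH(
        handler=my_handler,  # assumed: a handler whose output includes a boolean "filter" column
        segments={"train": ("2015-01-01", "2019-12-31"), "valid": ("2020-01-01", "2020-12-31")},
        step_len=30,
        flt_col="filter",  # default filter column for every prepare() call
    )
    train_sampler = ts_dataset.prepare("train")  # falls back to self.flt_col
    valid_sampler = ts_dataset.prepare("valid", flt_col="filter")  # per-call override

Note that passing flt_col=None at call time falls back to self.flt_col rather than disabling filtering, because the walrus expression treats None as "not provided".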

qlib/data/dataset/handler.py

Lines changed: 71 additions & 30 deletions
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.

 # coding=utf-8
+from abc import abstractmethod
 import warnings
 from typing import Callable, Union, Tuple, List, Iterator, Optional

@@ -19,9 +20,59 @@
 from . import loader as data_loader_module


-# TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed.
-class DataHandler(Serializable):
+DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
+
+
+class DataHandlerABC(Serializable):
+    """
+    Interface for data handlers.
+
+    This class does not assume the internal data structure of the data handler.
+    It only defines the interface for external users (which uses pd.DataFrame as the data format).
+
+    In the future, the data handler's more detailed implementation should be refactored. Here are some guidelines.
+
+    It covers several components:
+
+    - [data loader] -> internal representation of the data -> data preprocessing -> adaptor for the fetch interface
+    - The workflow that combines them all:
+      The workflow may be very complicated. DataHandlerLP is one practice, but it can't satisfy all requirements,
+      so leaving users the flexibility to implement the workflow is a more reasonable choice.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Subclasses should define here how to get ready for fetching.
+        """
+        super().__init__(*args, **kwargs)
+
+    CS_ALL = "__all"  # return all columns with single-level index columns
+    CS_RAW = "__raw"  # return raw data with multi-level index columns
+
+    # data key
+    DK_R: DATA_KEY_TYPE = "raw"
+    DK_I: DATA_KEY_TYPE = "infer"
+    DK_L: DATA_KEY_TYPE = "learn"
+
+    @abstractmethod
+    def fetch(
+        self,
+        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
+        level: Union[str, int] = "datetime",
+        col_set: Union[str, List[str]] = CS_ALL,
+        data_key: DATA_KEY_TYPE = DK_I,
+    ) -> pd.DataFrame:
+        pass
+
+
+class DataHandler(DataHandlerABC):
     """
+    The motivation of DataHandler:
+
+    - It provides an implementation of DataHandlerABC in which:
+      - responses are served from an internally loaded DataFrame;
+      - the DataFrame is loaded by a data loader.
+
     The steps to use a handler:
     1. initialize the data handler (call `init`).
     2. use the data.
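
A minimal sketch of what the new interface asks of a custom handler. InMemoryHandler and its single backing DataFrame are illustrative assumptions; a real handler would route data_key to differently processed data:

    import pandas as pd

    from qlib.data.dataset.handler import DATA_KEY_TYPE, DataHandlerABC
    from qlib.data.dataset.utils import fetch_df_by_col, fetch_df_by_index


    class InMemoryHandler(DataHandlerABC):
        """A hypothetical handler serving one pre-built DataFrame."""

        def __init__(self, df: pd.DataFrame, *args, **kwargs):
            self._df = df  # one DataFrame backs every data_key in this toy
            super().__init__(*args, **kwargs)

        def fetch(
            self,
            selector=slice(None, None),
            level="datetime",
            col_set=DataHandlerABC.CS_ALL,
            data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
        ) -> pd.DataFrame:
            # data_key is accepted but unused, mirroring DataHandler itself
            df = fetch_df_by_col(self._df, col_set)
            return fetch_df_by_index(df, selector, level, fetch_orig=True)
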
@@ -144,16 +195,14 @@ def setup_data(self, enable_cache: bool = False):
         self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time))
         # TODO: cache

-    CS_ALL = "__all"  # return all columns with single-level index column
-    CS_RAW = "__raw"  # return raw data with multi-level index column
-
     def fetch(
         self,
         selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
         level: Union[str, int] = "datetime",
-        col_set: Union[str, List[str]] = CS_ALL,
+        col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
+        data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I,
         squeeze: bool = False,
-        proc_func: Callable = None,
+        proc_func: Optional[Callable] = None,
     ) -> pd.DataFrame:
         """
         fetch data from underlying data source
@@ -216,6 +265,8 @@ def fetch(
         -------
         pd.DataFrame.
         """
+        # DataHandler keeps only one dataframe, so data_key is not used.
+        _ = data_key  # avoid linting errors (e.g., unused-argument)
         return self._fetch_data(
             data_storage=self._data,
             selector=selector,
@@ -230,7 +281,7 @@ def _fetch_data(
         data_storage,
         selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
         level: Union[str, int] = "datetime",
-        col_set: Union[str, List[str]] = CS_ALL,
+        col_set: Union[str, List[str]] = DataHandlerABC.CS_ALL,
         squeeze: bool = False,
         proc_func: Callable = None,
     ):
@@ -261,16 +312,9 @@ def _fetch_data(
             data_df = fetch_df_by_col(data_df, col_set)
             data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
         elif isinstance(data_storage, BaseHandlerStorage):
-            if not data_storage.is_proc_func_supported():
-                if proc_func is not None:
-                    raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
-                data_df = data_storage.fetch(
-                    selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
-                )
-            else:
-                data_df = data_storage.fetch(
-                    selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
-                )
+            if proc_func is not None:
+                raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
+            data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
         else:
             raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}")

@@ -282,7 +326,7 @@ def _fetch_data(
             data_df = data_df.reset_index(level=level, drop=True)
         return data_df

-    def get_cols(self, col_set=CS_ALL) -> list:
+    def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list:
         """
         get the column names
@@ -336,11 +380,12 @@ def get_range_iterator(
             yield cur_date, self.fetch(selector, **kwargs)


-DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
-
-
 class DataHandlerLP(DataHandler):
     """
+    Motivation:
+
+    - for cases where we want to use different processor workflows for learning and inference;
+
     DataHandler with **(L)earnable (P)rocessor**

     This handler will produce three pieces of data in pd.DataFrame format.
@@ -374,12 +419,8 @@ class DataHandlerLP(DataHandler):
     _infer: pd.DataFrame  # data for inference
     _learn: pd.DataFrame  # data for learning models

-    # data key
-    DK_R: DATA_KEY_TYPE = "raw"
-    DK_I: DATA_KEY_TYPE = "infer"
-    DK_L: DATA_KEY_TYPE = "learn"
     # map data_key to attribute name
-    ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
+    ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"}

     # process type
     PTYPE_I = "independent"
@@ -622,7 +663,7 @@ def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):

     # TODO: Be able to cache handler data. Save the memory for data processing

-    def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DK_I) -> pd.DataFrame:
+    def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame:
         if data_key == self.DK_R and self.drop_raw:
             raise AttributeError(
                 "DataHandlerLP has no attribute _data; please set drop_raw = False if you want to use raw data"
@@ -635,7 +676,7 @@ def fetch(
         selector: Union[pd.Timestamp, slice, str] = slice(None, None),
         level: Union[str, int] = "datetime",
         col_set=DataHandler.CS_ALL,
-        data_key: DATA_KEY_TYPE = DK_I,
+        data_key: DATA_KEY_TYPE = DataHandler.DK_I,
         squeeze: bool = False,
         proc_func: Callable = None,
     ) -> pd.DataFrame:
@@ -669,7 +710,7 @@ def fetch(
             proc_func=proc_func,
         )

-    def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DK_I) -> list:
+    def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list:
         """
         get the column names
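
Since the data-key constants now live on DataHandlerABC, they stay reachable through every subclass under the old names. A usage sketch (an initialized DataHandlerLP named `handler` is assumed):

    # learn-time vs. inference-time views of the same data
    df_learn = handler.fetch(col_set="feature", data_key=handler.DK_L)
    df_infer = handler.fetch(col_set="feature", data_key=handler.DK_I)

    # raw data is available only if the handler was created with drop_raw=False
    df_raw = handler.fetch(data_key=handler.DK_R)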

qlib/data/dataset/storage.py

Lines changed: 39 additions & 17 deletions
@@ -1,8 +1,10 @@
+from abc import abstractmethod
 import pandas as pd
 import numpy as np

 from .handler import DataHandler
-from typing import Union, List, Callable
+from typing import Union, List
+from qlib.log import get_module_logger

 from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col

@@ -14,14 +16,13 @@ class BaseHandlerStorage:
     - If users want to use custom data storage, they should define a subclass inheriting from BaseHandlerStorage and implement the following method
     """

+    @abstractmethod
     def fetch(
         self,
-        selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
+        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
         level: Union[str, int] = "datetime",
         col_set: Union[str, List[str]] = DataHandler.CS_ALL,
         fetch_orig: bool = True,
-        proc_func: Callable = None,
-        **kwargs,
     ) -> pd.DataFrame:
         """fetch data from the data storage
@@ -41,8 +42,6 @@ def fetch(
             select several sets of meaningful columns; the returned data has multiple levels
         fetch_orig : bool
             Return the original data instead of a copy if possible.
-        proc_func: Callable
-            please refer to the doc of DataHandler.fetch

         Returns
         -------
@@ -51,13 +50,40 @@ def fetch(
         """
         raise NotImplementedError("fetch method is not implemented!")

-    @staticmethod
-    def from_df(df: pd.DataFrame):
-        raise NotImplementedError("from_df method is not implemented!")

-    def is_proc_func_supported(self):
-        """whether the arg `proc_func` in `fetch` method is supported."""
-        raise NotImplementedError("is_proc_func_supported method is not implemented!")
+class NaiveDFStorage(BaseHandlerStorage):
+    """Naive data storage for a data handler
+    - NaiveDFStorage is a naive data storage for a data handler
+    - NaiveDFStorage takes a pandas.DataFrame as input and provides an interface for fetching data
+    """
+
+    def __init__(self, df: pd.DataFrame):
+        self.df = df
+
+    def fetch(
+        self,
+        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
+        level: Union[str, int] = "datetime",
+        col_set: Union[str, List[str]] = DataHandler.CS_ALL,
+        fetch_orig: bool = True,
+    ) -> pd.DataFrame:
+
+        # The following conflict may occur:
+        # - Does ["20200101", "20210101"] mean selecting this slice or these two days?
+        # To resolve this issue:
+        # - slices have higher priority (except when level is None)
+        if isinstance(selector, (tuple, list)) and level is not None:
+            # when level is None, the argument will be passed in directly
+            # we don't have to convert it into a slice
+            try:
+                selector = slice(*selector)
+            except ValueError:
+                get_module_logger("DataHandlerLP").info("Failed to convert the query to a slice. It will be used directly")
+
+        data_df = self.df
+        data_df = fetch_df_by_col(data_df, col_set)
+        data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=fetch_orig)
+        return data_df


 class HashingStockStorage(BaseHandlerStorage):
@@ -142,7 +168,7 @@ def _fetch_hash_df_by_stock(self, selector, level):

     def fetch(
         self,
-        selector: Union[pd.Timestamp, slice, str] = slice(None, None),
+        selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
         level: Union[str, int] = "datetime",
         col_set: Union[str, List[str]] = DataHandler.CS_ALL,
         fetch_orig: bool = True,
@@ -164,7 +190,3 @@ def fetch(
             return fetch_stock_df_list[0]
         else:
             return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
-
-    def is_proc_func_supported(self):
-        """the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
-        return False
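
A sketch of NaiveDFStorage's new list-to-slice behavior; the toy DataFrame below is an assumption for illustration:

    import pandas as pd

    from qlib.data.dataset.storage import NaiveDFStorage

    idx = pd.MultiIndex.from_product(
        [pd.date_range("2020-01-01", periods=5), ["SH600000"]],
        names=["datetime", "instrument"],
    )
    storage = NaiveDFStorage(pd.DataFrame({"feature": range(5)}, index=idx))

    # the two-element list is converted to slice("2020-01-01", "2020-01-03"),
    # i.e. a range query rather than two individual days
    df = storage.fetch(selector=["2020-01-01", "2020-01-03"], level="datetime")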

qlib/model/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -240,7 +240,9 @@ def __init__(
         self.train_func = train_func
         self._call_in_subproc = call_in_subproc

-    def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
+    def train(
+        self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs
+    ) -> List[Recorder]:
         """
         Given a list of `tasks`, return a list of trained Recorders. The order is guaranteed.
