
Commit fbba768

fixed a problem with multi index caused by the default value of group_keys (#1917)

* fixed a problem with multi index caused by the default value of group_keys
* modify group_keys default value
* limit pandas version
* format with black
* fix docs error
* fix docs error
* fixed bugs caused by pandas upgrade
* remove needless code
* reformat with black
* limit version & add docs

1 parent df557d2 · commit fbba768

43 files changed: +153 -98 lines

README.md
Lines changed: 8 additions & 0 deletions

@@ -462,6 +462,14 @@ python run_all_model.py run 10
 
 It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
 
+### Breaking change
+In `pandas`, `group_keys` is one of the parameters of the `groupby` method. Between versions 1.5 and 2.0 of `pandas`, the default value of `group_keys` changed from `no_default` to `True`, which causes qlib to raise errors at runtime. We therefore pass `group_keys=False` explicitly, but this does not guarantee that every program will run correctly; the following scripts may still be affected:
+* qlib\examples\rl_order_execution\scripts\gen_training_orders.py
+* qlib\examples\benchmarks\TRA\src\dataset.MTSDatasetH.py
+* qlib\examples\benchmarks\TFT\tft.py
+
+
+
 ## [Adapting to Market Dynamics](examples/benchmarks_dynamic)
 
 Due to the non-stationary nature of the financial market environment, the data distribution may change across periods, which makes the performance of models built on training data decay on future test data.
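
To see the change concretely, here is a minimal sketch (toy data, not qlib's; assumes pandas >= 2.0) of what the new group_keys=True default does to groupby(...).apply(...):

import pandas as pd

df = pd.DataFrame(
    {"instrument": ["A", "A", "B", "B"], "LABEL0": [1.0, 2.0, 3.0, 4.0]}
).set_index("instrument")

# With the pandas 2.0 default (group_keys=True), apply() prepends the group
# label as an extra index level; older pandas skipped this when the result
# kept the group's index, which is why upgraded code started failing.
kept = df.groupby("instrument", group_keys=True).apply(lambda d: d.shift(1))
flat = df.groupby("instrument", group_keys=False).apply(lambda d: d.shift(1))

print(kept.index.nlevels)  # 2 -- group label prepended as an outer level
print(flat.index.nlevels)  # 1 -- original index preserved, safe to assign back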

examples/benchmarks/TFT/libs/tft_model.py
Lines changed: 2 additions & 2 deletions

@@ -599,7 +599,7 @@ def _batch_sampled_data(self, data, max_samples):
         print("Getting valid sampling locations.")
         valid_sampling_locations = []
         split_data_map = {}
-        for identifier, df in data.groupby(id_col):
+        for identifier, df in data.groupby(id_col, group_keys=False):
             print("Getting locations for {}".format(identifier))
             num_entries = len(df)
             if num_entries >= self.time_steps:
@@ -678,7 +678,7 @@ def _batch_single_entity(input_data):
         input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
 
         data_map = {}
-        for _, sliced in data.groupby(id_col):
+        for _, sliced in data.groupby(id_col, group_keys=False):
             col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}
 
             for k in col_mappings:

examples/benchmarks/TFT/tft.py
Lines changed: 4 additions & 2 deletions

@@ -78,13 +78,15 @@
 
 
 def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
-    return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
+    return data_df[[col_shift]].groupby("instrument", group_keys=False).apply(lambda df: df.shift(shifts))
 
 
 def fill_test_na(test_df):
     test_df_res = test_df.copy()
     feature_cols = ~test_df_res.columns.str.contains("label", case=False)
-    test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
+    test_feature_fna = (
+        test_df_res.loc[:, feature_cols].groupby("datetime", group_keys=False).apply(lambda df: df.fillna(df.mean()))
+    )
     test_df_res.loc[:, feature_cols] = test_feature_fna
     return test_df_res
 
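
As a sanity check on the fill_test_na change above, a hedged sketch (toy MultiIndex data, not the real TFT dataset) showing why group_keys=False matters here: the per-datetime mean imputation must keep the original index so that the .loc assignment back into test_df_res aligns row-for-row.

import numpy as np
import pandas as pd

idx = pd.MultiIndex.from_product(
    [pd.to_datetime(["2020-01-01", "2020-01-02"]), ["A", "B"]],
    names=["datetime", "instrument"],
)
features = pd.DataFrame({"f0": [1.0, np.nan, 3.0, 4.0]}, index=idx)

# Fill each cross-section's NaNs with that datetime's mean.
filled = features.groupby("datetime", group_keys=False).apply(lambda d: d.fillna(d.mean()))

assert filled.index.equals(features.index)  # alignment preserved
print(filled["f0"].tolist())  # [1.0, 1.0, 3.0, 4.0]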

examples/benchmarks/TRA/src/dataset.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ def _create_ts_slices(index, seq_len):
     assert index.is_lexsorted(), "index should be sorted"
 
     # number of dates for each code
-    sample_count_by_codes = pd.Series(0, index=index).groupby(level=0).size().values
+    sample_count_by_codes = pd.Series(0, index=index).groupby(level=0, group_keys=False).size().values
 
     # start_index for each code
     start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1)

examples/highfreq/highfreq_ops.py
Lines changed: 3 additions & 3 deletions

@@ -25,7 +25,7 @@ class DayLast(ElemOperator):
     def _load_internal(self, instrument, start_index, end_index, freq):
         _calendar = get_calendar_day(freq=freq)
         series = self.feature.load(instrument, start_index, end_index, freq)
-        return series.groupby(_calendar[series.index]).transform("last")
+        return series.groupby(_calendar[series.index], group_keys=False).transform("last")
 
 
 class FFillNan(ElemOperator):
@@ -44,7 +44,7 @@ class FFillNan(ElemOperator):
 
     def _load_internal(self, instrument, start_index, end_index, freq):
         series = self.feature.load(instrument, start_index, end_index, freq)
-        return series.fillna(method="ffill")
+        return series.ffill()
 
 
 class BFillNan(ElemOperator):
@@ -63,7 +63,7 @@ class BFillNan(ElemOperator):
 
     def _load_internal(self, instrument, start_index, end_index, freq):
         series = self.feature.load(instrument, start_index, end_index, freq)
-        return series.fillna(method="bfill")
+        return series.bfill()
 
 
 class Date(ElemOperator):
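
The FFillNan/BFillNan edits track a separate pandas change: Series.fillna(method="ffill"/"bfill") is deprecated in pandas 2.x in favor of the dedicated ffill()/bfill() methods. A minimal sketch (toy series) of the equivalence:

import numpy as np
import pandas as pd

s = pd.Series([np.nan, 1.0, np.nan, 2.0, np.nan])

# ffill() propagates the last valid value forward; bfill() the next one backward.
print(s.ffill().tolist())  # [nan, 1.0, 1.0, 2.0, 2.0]
print(s.bfill().tolist())  # [1.0, 1.0, 2.0, 2.0, nan]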

examples/rl_order_execution/scripts/gen_training_orders.py
Lines changed: 2 additions & 2 deletions

@@ -19,9 +19,9 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool:
 
     df["date"] = df["datetime"].dt.date.astype("datetime64")
     df = df.set_index(["instrument", "datetime", "date"])
-    df = df.groupby("date").take(range(start_idx, end_idx)).droplevel(level=0)
+    df = df.groupby("date", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0)
 
-    order_all = pd.DataFrame(df.groupby(level=(2, 0)).mean().dropna())
+    order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna())
     order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"]
     order_all = order_all[order_all["amount"] > 0.0]
     order_all["order_type"] = 0

pyproject.toml
Lines changed: 4 additions & 1 deletion

@@ -26,7 +26,7 @@ readme = {file = "README.md", content-type = "text/markdown"}
 dependencies = [
     "pyyaml",
     "numpy",
-    "pandas",
+    "pandas>=0.24",
     "mlflow",
     "filelock>=3.16.0",
     "redis",
@@ -67,10 +67,13 @@ lint = [
     "flake8",
     "nbqa",
 ]
+# snowballstemmer, a dependency of sphinx, was released on 2025-05-08 with version 3.0.0,
+# which causes errors in the build process. So we've limited the version for now.
 docs = [
     "sphinx",
     "sphinx_rtd_theme",
     "readthedocs_sphinx_ext",
+    "snowballstemmer<3.0",
 ]
 package = [
     "twine",
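
For illustration only (the packaging library is not part of this commit), a small sketch of how the new snowballstemmer<3.0 pin behaves:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

pin = SpecifierSet("<3.0")  # mirrors "snowballstemmer<3.0" in pyproject.toml

print(Version("2.2.0") in pin)  # True  -> 2.x releases still install
print(Version("3.0.0") in pin)  # False -> the breaking 3.0.0 release is excluded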

qlib/backtest/high_performance_ds.py
Lines changed: 2 additions & 2 deletions

@@ -104,7 +104,7 @@ class PandasQuote(BaseQuote):
     def __init__(self, quote_df: pd.DataFrame, freq: str) -> None:
         super().__init__(quote_df=quote_df, freq=freq)
         quote_dict = {}
-        for stock_id, stock_val in quote_df.groupby(level="instrument"):
+        for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
             quote_dict[stock_id] = stock_val.droplevel(level="instrument")
         self.data = quote_dict
 
@@ -137,7 +137,7 @@ def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> Non
         """
         super().__init__(quote_df=quote_df, freq=freq)
         quote_dict = {}
-        for stock_id, stock_val in quote_df.groupby(level="instrument"):
+        for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
             quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument"))
             quote_dict[stock_id].sort_index()  # To support more flexible slicing, we must sort data first
         self.data = quote_dict
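
A hedged sketch (toy quote data with hypothetical instrument names) of the quote_dict construction above: iterating the groupby yields one sub-frame per instrument, and droplevel removes the now-constant instrument level.

import pandas as pd

idx = pd.MultiIndex.from_product(
    [["SH600000", "SZ000001"], pd.to_datetime(["2020-01-01", "2020-01-02"])],
    names=["instrument", "datetime"],
)
quote_df = pd.DataFrame({"$close": [10.0, 10.5, 20.0, 19.5]}, index=idx)

quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False):
    # Each sub-frame still carries the constant instrument level; drop it.
    quote_dict[stock_id] = stock_val.droplevel(level="instrument")

print(list(quote_dict))                   # ['SH600000', 'SZ000001']
print(quote_dict["SH600000"].index.name)  # 'datetime'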

qlib/backtest/position.py
Lines changed: 1 addition & 1 deletion

@@ -311,7 +311,7 @@ def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last
             freq=freq,
             disk_cache=True,
         ).dropna()
-        price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
+        price_dict = price_df.groupby(["instrument"], group_keys=False).tail(1)["$close"].to_dict()
 
         if len(price_dict) < len(stock_list):
             lack_stock = set(stock_list) - set(price_dict)

qlib/backtest/report.py
Lines changed: 5 additions & 1 deletion

@@ -114,7 +114,11 @@ def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.S
         _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
         if len(_temp_result) == 0:
             raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
-        return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
+        return (
+            _temp_result.groupby(level="datetime", group_keys=False)[_temp_result.columns.tolist()[0]]
+            .mean()
+            .fillna(0)
+        )
 
     def _sample_benchmark(
         self,
