Skip to content

Commit 078c1f6

Browse files
author
Tonny@Home
committed
test: enhance ensemble related test
1 parent 59e2977 commit 078c1f6

File tree

4 files changed

+818
-5
lines changed

4 files changed

+818
-5
lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ jobs:
1010
test:
1111
runs-on: ubuntu-latest
1212
strategy:
13+
fail-fast: false
1314
matrix:
14-
python-version: ["3.12"]
15+
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
1516

1617
steps:
1718
- uses: actions/checkout@v4

tests/quantpits/scripts/test_brute_force_ensemble.py

Lines changed: 204 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@
33
import pandas as pd
44
import numpy as np
55
import yaml
6-
from unittest.mock import MagicMock
6+
import json
7+
import sys
8+
import gc
9+
import signal
10+
from unittest.mock import MagicMock, patch, mock_open
11+
from io import StringIO
12+
from datetime import datetime
713

814
@pytest.fixture(autouse=True)
915
def mock_env(monkeypatch, tmp_path):
1016
workspace = tmp_path / "MockWorkspace"
1117
workspace.mkdir()
18+
(workspace / "config").mkdir()
1219

1320
import sys
1421
monkeypatch.setattr(sys, 'argv', ['script.py'])
@@ -38,6 +45,12 @@ def test_zscore_norm(mock_env):
3845
normed = bfe.zscore_norm(series)
3946
assert np.isclose(normed.xs("2020-01-01", level="datetime").mean(), 0)
4047
assert np.isclose(normed.xs("2020-01-01", level="datetime").std(), 1)
48+
49+
# Zero std case (lines 138-139)
50+
zero_series = pd.Series([1.0, 1.0, 1.0], index=pd.MultiIndex.from_arrays([
51+
pd.to_datetime(["2020-01-01"]*3), ["A", "B", "C"]
52+
], names=["datetime", "instrument"]))
53+
assert np.all(bfe.zscore_norm(zero_series) == 0.0)
4154

4255
# ── load_combo_groups ────────────────────────────────────────────────────
4356
def test_load_combo_groups(mock_env):
@@ -65,6 +78,15 @@ def test_load_combo_groups_empty(mock_env):
6578
with pytest.raises(ValueError, match="为空"):
6679
bfe.load_combo_groups(str(cfg_path), available_models=["m1"])
6780

81+
def test_load_combo_groups_warnings(mock_env, tmp_path):
82+
bfe, _ = mock_env
83+
cfg_path = tmp_path / "groups_warn.yaml"
84+
cfg_path.write_text(yaml.dump({"groups": {"G1": ["m_missing"]}}))
85+
with patch('builtins.print') as mock_print:
86+
groups = bfe.load_combo_groups(str(cfg_path), ["m1"])
87+
assert "G1" not in groups
88+
assert any("组 [G1] 无有效模型,已跳过" in str(call) for call in mock_print.call_args_list)
89+
6890
# ── generate_grouped_combinations ────────────────────────────────────────
6991
def test_generate_grouped_combinations(mock_env):
7092
bfe, _ = mock_env
@@ -138,6 +160,10 @@ def test_extract_report_df(mock_env):
138160
assert bfe.extract_report_df(metrics_tuple).equals(df)
139161

140162
assert bfe.extract_report_df(df).equals(df)
163+
164+
# Case: metrics is a tuple, first element is a tuple (lines 368-370)
165+
metrics_nested = ((df, "extra"),)
166+
assert bfe.extract_report_df(metrics_nested) is df
141167

142168
# ── load_config ──────────────────────────────────────────────────────────
143169
def test_load_config(mock_env, tmp_path):
@@ -154,6 +180,10 @@ def test_load_config(mock_env, tmp_path):
154180

155181
assert records["models"]["m1"] == "r1"
156182
assert model_config["TopK"] == 20
183+
184+
# No file case (line 122)
185+
tr, mc = bfe.load_config("non_existent_records.json")
186+
assert tr == {"models": {}, "experiment_name": "unknown"}
157187

158188
# ── correlation_analysis ─────────────────────────────────────────────────
159189
def test_correlation_analysis(mock_env, tmp_path):
@@ -185,8 +215,16 @@ def test_signal_handler_sets_shutdown(mock_env):
185215
bfe, _ = mock_env
186216
import quantpits.scripts.brute_force_ensemble as bfe_mod
187217
bfe_mod._shutdown = False
188-
bfe._signal_handler(2, None) # SIGINT = 2
189-
assert bfe_mod._shutdown is True
218+
219+
with patch('builtins.print') as mock_print:
220+
bfe._signal_handler(2, None) # SIGINT = 2
221+
assert bfe_mod._shutdown is True
222+
assert any("安全退出" in str(call) for call in mock_print.call_args_list)
223+
224+
# Second call exits (lines 81-82)
225+
with patch('sys.exit') as mock_exit:
226+
bfe._signal_handler(2, None)
227+
mock_exit.assert_called_with(1)
190228

191229
# ── _append_results_to_csv ───────────────────────────────────────────────
192230
def test_append_results_to_csv(mock_env, tmp_path):
@@ -654,4 +692,167 @@ def test_main_full(mock_safeguard, mock_analyze, mock_bf, mock_corr, mock_load_p
654692
mock_bf.assert_called_once()
655693
mock_analyze.assert_called_once()
656694

695+
@patch('qlib.backtest.account.Account')
696+
@patch('qlib.backtest.executor.SimulatorExecutor')
697+
@patch('qlib.backtest.utils.CommonInfrastructure')
698+
@patch('qlib.backtest.backtest_loop')
699+
@patch('quantpits.scripts.strategy.load_strategy_config')
700+
@patch('quantpits.scripts.strategy.get_backtest_config')
701+
@patch('quantpits.scripts.strategy.create_backtest_strategy')
702+
@patch('quantpits.scripts.analysis.portfolio_analyzer.PortfolioAnalyzer')
703+
def test_run_single_backtest_non_datetime_index(mock_pa, mock_st_create, mock_bt_cfg, mock_st_cfg, mock_bt_loop, mock_infra, mock_executor, mock_account, mock_env):
704+
bfe, _ = mock_env
705+
dates = ["2020-01-01", "2020-01-02"]
706+
report = pd.DataFrame({
707+
"account": [100000, 101000],
708+
"bench": [0.001, 0.002]
709+
}, index=dates) # String index
710+
711+
mock_st_cfg.return_value = {"benchmark": "SH000300"}
712+
mock_bt_cfg.return_value = {"account": 1000000}
713+
mock_bt_loop.return_value = (report, None)
714+
715+
mock_pa_inst = MagicMock()
716+
mock_pa_inst.calculate_traditional_metrics.return_value = {"CAGR": 0.1}
717+
mock_pa.return_value = mock_pa_inst
718+
719+
idx = pd.MultiIndex.from_product([pd.to_datetime(dates), ["A"]], names=["datetime", "instrument"])
720+
norm_df = pd.DataFrame({"m1": [0.5, 0.6]}, index=idx)
721+
722+
with patch('qlib.data.D') as mock_D_inst:
723+
mock_D_inst.calendar.return_value = pd.to_datetime(dates)
724+
res = bfe.run_single_backtest(["m1"], norm_df, 1, 0, "SH000300", "day", MagicMock(), "2020-01-01", "2020-01-02")
725+
assert res["Ann_Ret"] == 0.1
726+
727+
def test_analyze_results_clustering_and_opt_fails(mock_env, tmp_path):
728+
bfe, _ = mock_env
729+
out_dir = tmp_path / "complex_fail"
730+
out_dir.mkdir()
731+
results_df = pd.DataFrame({
732+
"models": ["m1", "m2"],
733+
"n_models": [1, 1],
734+
"Ann_Excess": [0.1, 0.12],
735+
"Calmar": [1.0, 1.2],
736+
"Ann_Ret": [0.15, 0.16],
737+
"Max_DD": [-0.1, -0.1]
738+
})
739+
results_df["diversity_bonus"] = [0.01, 0.02]
740+
741+
with patch('qlib.workflow.R') as mock_R_inst:
742+
mock_R_inst.get_recorder.return_value.load_object.side_effect = Exception("Load fail")
743+
with patch('builtins.print') as mock_print:
744+
bfe.analyze_results(results_df, pd.DataFrame(), pd.DataFrame(), {"experiment_name": "exp", "models": {"m1": "r1"}}, str(out_dir), "test")
745+
assert any("[跳过] m1: Load fail" in str(call) for call in mock_print.call_args_list)
746+
747+
@patch('quantpits.scripts.brute_force_ensemble.run_single_backtest')
748+
@patch('quantpits.scripts.brute_force_ensemble.load_config')
749+
def test_main_oos_validation_stage5(mock_load_cfg, mock_run, mock_env, tmp_path):
750+
bfe, _ = mock_env
751+
out_dir = tmp_path / "oos_stage5"
752+
out_dir.mkdir()
753+
754+
res_df = pd.DataFrame({"models": ["m1"], "Ann_Excess": [0.1]})
755+
res_df.to_csv(out_dir / "brute_force_results_test.csv", index=False)
756+
757+
idx = pd.MultiIndex.from_product([pd.to_datetime(["2021-01-01"]), ["A"]], names=["datetime", "instrument"])
758+
norm_df = pd.DataFrame({"m1": [0.5]}, index=idx)
759+
760+
mock_load_cfg.return_value = ({"anchor_date": "test", "experiment_name": "exp", "models": {"m1": "r1"}}, {"TopK": 1})
761+
mock_run.return_value = {"models": "m1", "Ann_Ret": 0.1, "Max_DD": -0.05, "Ann_Excess": 0.05, "Calmar": 2.0}
762+
763+
with patch('sys.argv', ['script.py', '--auto-test-top', '1', '--output-dir', str(out_dir)]):
764+
with patch('quantpits.scripts.brute_force_ensemble.load_predictions', return_value=(norm_df, {})):
765+
with patch('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args', return_value=(norm_df, norm_df)):
766+
with patch('quantpits.scripts.brute_force_ensemble.brute_force_backtest', return_value=res_df):
767+
with patch('quantpits.scripts.brute_force_ensemble.analyze_results'):
768+
with patch('qlib.backtest.exchange.Exchange'):
769+
with patch('builtins.print'):
770+
bfe.main()
771+
assert os.path.exists(out_dir / "oos_validation_test.csv")
772+
773+
@patch('qlib.backtest.exchange.Exchange')
774+
def test_brute_force_backtest_grouped_and_no_pending(mock_exch, mock_env, tmp_path):
775+
bfe, _ = mock_env
776+
out_dir = tmp_path / "bfe_out"
777+
out_dir.mkdir()
778+
idx = pd.MultiIndex.from_product([pd.to_datetime(["2020-01-01"]), ["A"]], names=["datetime", "instrument"])
779+
norm_df = pd.DataFrame({"m1": [0.5], "m2": [0.6]}, index=idx)
780+
781+
# 1. Test no pending tasks (lines 584-588)
782+
csv_path = out_dir / "brute_force_results_test.csv"
783+
pd.DataFrame({"models": ["m1", "m2"], "Ann_Excess": [0.1, 0.2]}).to_csv(csv_path, index=False)
784+
785+
with patch('builtins.print') as mock_print:
786+
res = bfe.brute_force_backtest(norm_df, 1, 0, "BENCH", "day", 1, 1, str(out_dir), "test", resume=True)
787+
assert len(res) == 2
788+
assert any("所有组合已完成!" in str(call) for call in mock_print.call_args_list)
789+
790+
@patch('quantpits.scripts.brute_force_ensemble.run_single_backtest', return_value=None)
791+
@patch('qlib.backtest.exchange.Exchange')
792+
def test_brute_force_backtest_no_results(mock_exch, mock_run, mock_env, tmp_path):
793+
bfe, _ = mock_env
794+
out_dir = tmp_path / "no_res"
795+
out_dir.mkdir()
796+
idx = pd.MultiIndex.from_product([pd.to_datetime(["2020-01-01"]), ["A"]], names=["datetime", "instrument"])
797+
norm_df = pd.DataFrame({"m1": [0.5]}, index=idx)
798+
799+
with patch('builtins.print') as mock_print:
800+
bfe.brute_force_backtest(norm_df, 1, 0, "BENCH", "day", 1, 1, str(out_dir), "no_res")
801+
assert any("警告: 无有效回测结果" in str(call) for call in mock_print.call_args_list)
802+
803+
@patch('qlib.workflow.R')
804+
def test_analyze_results_plot_fails(mock_R, mock_env, tmp_path):
805+
bfe, _ = mock_env
806+
out_dir = tmp_path / "plot_fail"
807+
out_dir.mkdir()
808+
results_df = pd.DataFrame({"models": ["m1"], "n_models": [1], "Ann_Excess": [0.1], "Calmar": [1.0], "Ann_Ret": [0.15], "Max_DD": [-0.1]})
809+
results_df["diversity_bonus"] = [0.0]
810+
mock_R.get_recorder.return_value.load_object.side_effect = Exception("Skip cluster")
811+
812+
with patch('matplotlib.pyplot.savefig', side_effect=Exception("Save fail")):
813+
with patch('builtins.print') as mock_print:
814+
bfe.analyze_results(results_df, pd.DataFrame(), pd.DataFrame(), {"experiment_name": "exp", "models": {"m1": "r1"}}, str(out_dir), "test")
815+
assert any("归因图绘制失败: Save fail" in str(call) for call in mock_print.call_args_list)
816+
817+
def test_main_empty_is_exit(mock_env):
818+
bfe, _ = mock_env
819+
with patch('quantpits.scripts.brute_force_ensemble.load_predictions', return_value=(pd.DataFrame(), {})):
820+
with patch('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args', return_value=(pd.DataFrame(), pd.DataFrame())):
821+
with patch('builtins.print'):
822+
with pytest.raises(SystemExit) as e:
823+
bfe.main()
824+
assert e.value.code == 1
825+
826+
def test_main_analysis_only_glob(mock_env, tmp_path):
827+
bfe, _ = mock_env
828+
out_dir = tmp_path / "glob_test"
829+
out_dir.mkdir()
830+
pd.DataFrame({"models": ["m1"], "Ann_Excess": [0.1]}).to_csv(out_dir / "brute_force_results_2020.csv", index=False)
831+
idx = pd.MultiIndex.from_product([pd.to_datetime(["2020-01-01"]), ["A"]], names=["datetime", "instrument"])
832+
norm_df = pd.DataFrame({"m1": [0.5]}, index=idx)
833+
834+
with patch('sys.argv', ['script.py', '--analysis-only', '--output-dir', str(out_dir)]):
835+
with patch('quantpits.scripts.brute_force_ensemble.load_predictions', return_value=(norm_df, {})):
836+
with patch('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args', return_value=(norm_df, pd.DataFrame())):
837+
with patch('builtins.print') as mock_print:
838+
with patch('quantpits.scripts.brute_force_ensemble.analyze_results'):
839+
bfe.main()
840+
assert any("使用最新结果文件" in str(call) for call in mock_print.call_args_list)
841+
842+
@patch('quantpits.scripts.brute_force_ensemble.load_config')
843+
def test_main_oos_validation_no_data(mock_load_cfg, mock_env, tmp_path):
844+
bfe, _ = mock_env
845+
out_dir = tmp_path / "oos_no_data"
846+
out_dir.mkdir()
847+
mock_load_cfg.return_value = ({"anchor_date": "test", "experiment_name": "exp", "models": {"m1": "r1"}}, {"TopK": 1})
848+
norm_df = pd.DataFrame({"m1": [0.5]}, index=pd.MultiIndex.from_product([pd.to_datetime(["2021-01-01"]), ["A"]], names=["datetime", "instrument"]))
849+
with patch('sys.argv', ['script.py', '--auto-test-top', '1', '--output-dir', str(out_dir)]):
850+
with patch('quantpits.scripts.brute_force_ensemble.load_predictions', return_value=(norm_df, {})):
851+
with patch('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args', return_value=(norm_df, pd.DataFrame())):
852+
with patch('quantpits.scripts.brute_force_ensemble.brute_force_backtest', return_value=pd.DataFrame({"models": ["m1"], "Ann_Excess": [0.1]})):
853+
with patch('quantpits.scripts.brute_force_ensemble.analyze_results'):
854+
with patch('builtins.print') as mock_print:
855+
bfe.main()
856+
assert any("无法进行 OOS 验证:无 OOS 数据" in str(call) for call in mock_print.call_args_list)
857+
657858

0 commit comments

Comments
 (0)