33import pandas as pd
44import numpy as np
55import yaml
6- from unittest .mock import MagicMock
6+ import json
7+ import sys
8+ import gc
9+ import signal
10+ from unittest .mock import MagicMock , patch , mock_open
11+ from io import StringIO
12+ from datetime import datetime
713
814@pytest .fixture (autouse = True )
915def mock_env (monkeypatch , tmp_path ):
1016 workspace = tmp_path / "MockWorkspace"
1117 workspace .mkdir ()
18+ (workspace / "config" ).mkdir ()
1219
1320 import sys
1421 monkeypatch .setattr (sys , 'argv' , ['script.py' ])
@@ -38,6 +45,12 @@ def test_zscore_norm(mock_env):
3845 normed = bfe .zscore_norm (series )
3946 assert np .isclose (normed .xs ("2020-01-01" , level = "datetime" ).mean (), 0 )
4047 assert np .isclose (normed .xs ("2020-01-01" , level = "datetime" ).std (), 1 )
48+
49+ # Zero std case (lines 138-139)
50+ zero_series = pd .Series ([1.0 , 1.0 , 1.0 ], index = pd .MultiIndex .from_arrays ([
51+ pd .to_datetime (["2020-01-01" ]* 3 ), ["A" , "B" , "C" ]
52+ ], names = ["datetime" , "instrument" ]))
53+ assert np .all (bfe .zscore_norm (zero_series ) == 0.0 )
4154
4255# ── load_combo_groups ────────────────────────────────────────────────────
4356def test_load_combo_groups (mock_env ):
@@ -65,6 +78,15 @@ def test_load_combo_groups_empty(mock_env):
6578 with pytest .raises (ValueError , match = "为空" ):
6679 bfe .load_combo_groups (str (cfg_path ), available_models = ["m1" ])
6780
81+ def test_load_combo_groups_warnings (mock_env , tmp_path ):
82+ bfe , _ = mock_env
83+ cfg_path = tmp_path / "groups_warn.yaml"
84+ cfg_path .write_text (yaml .dump ({"groups" : {"G1" : ["m_missing" ]}}))
85+ with patch ('builtins.print' ) as mock_print :
86+ groups = bfe .load_combo_groups (str (cfg_path ), ["m1" ])
87+ assert "G1" not in groups
88+ assert any ("组 [G1] 无有效模型,已跳过" in str (call ) for call in mock_print .call_args_list )
89+
6890# ── generate_grouped_combinations ────────────────────────────────────────
6991def test_generate_grouped_combinations (mock_env ):
7092 bfe , _ = mock_env
@@ -138,6 +160,10 @@ def test_extract_report_df(mock_env):
138160 assert bfe .extract_report_df (metrics_tuple ).equals (df )
139161
140162 assert bfe .extract_report_df (df ).equals (df )
163+
164+ # Case: metrics is a tuple, first element is a tuple (lines 368-370)
165+ metrics_nested = ((df , "extra" ),)
166+ assert bfe .extract_report_df (metrics_nested ) is df
141167
142168# ── load_config ──────────────────────────────────────────────────────────
143169def test_load_config (mock_env , tmp_path ):
@@ -154,6 +180,10 @@ def test_load_config(mock_env, tmp_path):
154180
155181 assert records ["models" ]["m1" ] == "r1"
156182 assert model_config ["TopK" ] == 20
183+
184+ # No file case (line 122)
185+ tr , mc = bfe .load_config ("non_existent_records.json" )
186+ assert tr == {"models" : {}, "experiment_name" : "unknown" }
157187
158188# ── correlation_analysis ─────────────────────────────────────────────────
159189def test_correlation_analysis (mock_env , tmp_path ):
@@ -185,8 +215,16 @@ def test_signal_handler_sets_shutdown(mock_env):
185215 bfe , _ = mock_env
186216 import quantpits .scripts .brute_force_ensemble as bfe_mod
187217 bfe_mod ._shutdown = False
188- bfe ._signal_handler (2 , None ) # SIGINT = 2
189- assert bfe_mod ._shutdown is True
218+
219+ with patch ('builtins.print' ) as mock_print :
220+ bfe ._signal_handler (2 , None ) # SIGINT = 2
221+ assert bfe_mod ._shutdown is True
222+ assert any ("安全退出" in str (call ) for call in mock_print .call_args_list )
223+
224+ # Second call exits (lines 81-82)
225+ with patch ('sys.exit' ) as mock_exit :
226+ bfe ._signal_handler (2 , None )
227+ mock_exit .assert_called_with (1 )
190228
191229# ── _append_results_to_csv ───────────────────────────────────────────────
192230def test_append_results_to_csv (mock_env , tmp_path ):
@@ -654,4 +692,167 @@ def test_main_full(mock_safeguard, mock_analyze, mock_bf, mock_corr, mock_load_p
654692 mock_bf .assert_called_once ()
655693 mock_analyze .assert_called_once ()
656694
695+ @patch ('qlib.backtest.account.Account' )
696+ @patch ('qlib.backtest.executor.SimulatorExecutor' )
697+ @patch ('qlib.backtest.utils.CommonInfrastructure' )
698+ @patch ('qlib.backtest.backtest_loop' )
699+ @patch ('quantpits.scripts.strategy.load_strategy_config' )
700+ @patch ('quantpits.scripts.strategy.get_backtest_config' )
701+ @patch ('quantpits.scripts.strategy.create_backtest_strategy' )
702+ @patch ('quantpits.scripts.analysis.portfolio_analyzer.PortfolioAnalyzer' )
703+ def test_run_single_backtest_non_datetime_index (mock_pa , mock_st_create , mock_bt_cfg , mock_st_cfg , mock_bt_loop , mock_infra , mock_executor , mock_account , mock_env ):
704+ bfe , _ = mock_env
705+ dates = ["2020-01-01" , "2020-01-02" ]
706+ report = pd .DataFrame ({
707+ "account" : [100000 , 101000 ],
708+ "bench" : [0.001 , 0.002 ]
709+ }, index = dates ) # String index
710+
711+ mock_st_cfg .return_value = {"benchmark" : "SH000300" }
712+ mock_bt_cfg .return_value = {"account" : 1000000 }
713+ mock_bt_loop .return_value = (report , None )
714+
715+ mock_pa_inst = MagicMock ()
716+ mock_pa_inst .calculate_traditional_metrics .return_value = {"CAGR" : 0.1 }
717+ mock_pa .return_value = mock_pa_inst
718+
719+ idx = pd .MultiIndex .from_product ([pd .to_datetime (dates ), ["A" ]], names = ["datetime" , "instrument" ])
720+ norm_df = pd .DataFrame ({"m1" : [0.5 , 0.6 ]}, index = idx )
721+
722+ with patch ('qlib.data.D' ) as mock_D_inst :
723+ mock_D_inst .calendar .return_value = pd .to_datetime (dates )
724+ res = bfe .run_single_backtest (["m1" ], norm_df , 1 , 0 , "SH000300" , "day" , MagicMock (), "2020-01-01" , "2020-01-02" )
725+ assert res ["Ann_Ret" ] == 0.1
726+
727+ def test_analyze_results_clustering_and_opt_fails (mock_env , tmp_path ):
728+ bfe , _ = mock_env
729+ out_dir = tmp_path / "complex_fail"
730+ out_dir .mkdir ()
731+ results_df = pd .DataFrame ({
732+ "models" : ["m1" , "m2" ],
733+ "n_models" : [1 , 1 ],
734+ "Ann_Excess" : [0.1 , 0.12 ],
735+ "Calmar" : [1.0 , 1.2 ],
736+ "Ann_Ret" : [0.15 , 0.16 ],
737+ "Max_DD" : [- 0.1 , - 0.1 ]
738+ })
739+ results_df ["diversity_bonus" ] = [0.01 , 0.02 ]
740+
741+ with patch ('qlib.workflow.R' ) as mock_R_inst :
742+ mock_R_inst .get_recorder .return_value .load_object .side_effect = Exception ("Load fail" )
743+ with patch ('builtins.print' ) as mock_print :
744+ bfe .analyze_results (results_df , pd .DataFrame (), pd .DataFrame (), {"experiment_name" : "exp" , "models" : {"m1" : "r1" }}, str (out_dir ), "test" )
745+ assert any ("[跳过] m1: Load fail" in str (call ) for call in mock_print .call_args_list )
746+
747+ @patch ('quantpits.scripts.brute_force_ensemble.run_single_backtest' )
748+ @patch ('quantpits.scripts.brute_force_ensemble.load_config' )
749+ def test_main_oos_validation_stage5 (mock_load_cfg , mock_run , mock_env , tmp_path ):
750+ bfe , _ = mock_env
751+ out_dir = tmp_path / "oos_stage5"
752+ out_dir .mkdir ()
753+
754+ res_df = pd .DataFrame ({"models" : ["m1" ], "Ann_Excess" : [0.1 ]})
755+ res_df .to_csv (out_dir / "brute_force_results_test.csv" , index = False )
756+
757+ idx = pd .MultiIndex .from_product ([pd .to_datetime (["2021-01-01" ]), ["A" ]], names = ["datetime" , "instrument" ])
758+ norm_df = pd .DataFrame ({"m1" : [0.5 ]}, index = idx )
759+
760+ mock_load_cfg .return_value = ({"anchor_date" : "test" , "experiment_name" : "exp" , "models" : {"m1" : "r1" }}, {"TopK" : 1 })
761+ mock_run .return_value = {"models" : "m1" , "Ann_Ret" : 0.1 , "Max_DD" : - 0.05 , "Ann_Excess" : 0.05 , "Calmar" : 2.0 }
762+
763+ with patch ('sys.argv' , ['script.py' , '--auto-test-top' , '1' , '--output-dir' , str (out_dir )]):
764+ with patch ('quantpits.scripts.brute_force_ensemble.load_predictions' , return_value = (norm_df , {})):
765+ with patch ('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args' , return_value = (norm_df , norm_df )):
766+ with patch ('quantpits.scripts.brute_force_ensemble.brute_force_backtest' , return_value = res_df ):
767+ with patch ('quantpits.scripts.brute_force_ensemble.analyze_results' ):
768+ with patch ('qlib.backtest.exchange.Exchange' ):
769+ with patch ('builtins.print' ):
770+ bfe .main ()
771+ assert os .path .exists (out_dir / "oos_validation_test.csv" )
772+
773+ @patch ('qlib.backtest.exchange.Exchange' )
774+ def test_brute_force_backtest_grouped_and_no_pending (mock_exch , mock_env , tmp_path ):
775+ bfe , _ = mock_env
776+ out_dir = tmp_path / "bfe_out"
777+ out_dir .mkdir ()
778+ idx = pd .MultiIndex .from_product ([pd .to_datetime (["2020-01-01" ]), ["A" ]], names = ["datetime" , "instrument" ])
779+ norm_df = pd .DataFrame ({"m1" : [0.5 ], "m2" : [0.6 ]}, index = idx )
780+
781+ # 1. Test no pending tasks (lines 584-588)
782+ csv_path = out_dir / "brute_force_results_test.csv"
783+ pd .DataFrame ({"models" : ["m1" , "m2" ], "Ann_Excess" : [0.1 , 0.2 ]}).to_csv (csv_path , index = False )
784+
785+ with patch ('builtins.print' ) as mock_print :
786+ res = bfe .brute_force_backtest (norm_df , 1 , 0 , "BENCH" , "day" , 1 , 1 , str (out_dir ), "test" , resume = True )
787+ assert len (res ) == 2
788+ assert any ("所有组合已完成!" in str (call ) for call in mock_print .call_args_list )
789+
790+ @patch ('quantpits.scripts.brute_force_ensemble.run_single_backtest' , return_value = None )
791+ @patch ('qlib.backtest.exchange.Exchange' )
792+ def test_brute_force_backtest_no_results (mock_exch , mock_run , mock_env , tmp_path ):
793+ bfe , _ = mock_env
794+ out_dir = tmp_path / "no_res"
795+ out_dir .mkdir ()
796+ idx = pd .MultiIndex .from_product ([pd .to_datetime (["2020-01-01" ]), ["A" ]], names = ["datetime" , "instrument" ])
797+ norm_df = pd .DataFrame ({"m1" : [0.5 ]}, index = idx )
798+
799+ with patch ('builtins.print' ) as mock_print :
800+ bfe .brute_force_backtest (norm_df , 1 , 0 , "BENCH" , "day" , 1 , 1 , str (out_dir ), "no_res" )
801+ assert any ("警告: 无有效回测结果" in str (call ) for call in mock_print .call_args_list )
802+
803+ @patch ('qlib.workflow.R' )
804+ def test_analyze_results_plot_fails (mock_R , mock_env , tmp_path ):
805+ bfe , _ = mock_env
806+ out_dir = tmp_path / "plot_fail"
807+ out_dir .mkdir ()
808+ results_df = pd .DataFrame ({"models" : ["m1" ], "n_models" : [1 ], "Ann_Excess" : [0.1 ], "Calmar" : [1.0 ], "Ann_Ret" : [0.15 ], "Max_DD" : [- 0.1 ]})
809+ results_df ["diversity_bonus" ] = [0.0 ]
810+ mock_R .get_recorder .return_value .load_object .side_effect = Exception ("Skip cluster" )
811+
812+ with patch ('matplotlib.pyplot.savefig' , side_effect = Exception ("Save fail" )):
813+ with patch ('builtins.print' ) as mock_print :
814+ bfe .analyze_results (results_df , pd .DataFrame (), pd .DataFrame (), {"experiment_name" : "exp" , "models" : {"m1" : "r1" }}, str (out_dir ), "test" )
815+ assert any ("归因图绘制失败: Save fail" in str (call ) for call in mock_print .call_args_list )
816+
817+ def test_main_empty_is_exit (mock_env ):
818+ bfe , _ = mock_env
819+ with patch ('quantpits.scripts.brute_force_ensemble.load_predictions' , return_value = (pd .DataFrame (), {})):
820+ with patch ('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args' , return_value = (pd .DataFrame (), pd .DataFrame ())):
821+ with patch ('builtins.print' ):
822+ with pytest .raises (SystemExit ) as e :
823+ bfe .main ()
824+ assert e .value .code == 1
825+
826+ def test_main_analysis_only_glob (mock_env , tmp_path ):
827+ bfe , _ = mock_env
828+ out_dir = tmp_path / "glob_test"
829+ out_dir .mkdir ()
830+ pd .DataFrame ({"models" : ["m1" ], "Ann_Excess" : [0.1 ]}).to_csv (out_dir / "brute_force_results_2020.csv" , index = False )
831+ idx = pd .MultiIndex .from_product ([pd .to_datetime (["2020-01-01" ]), ["A" ]], names = ["datetime" , "instrument" ])
832+ norm_df = pd .DataFrame ({"m1" : [0.5 ]}, index = idx )
833+
834+ with patch ('sys.argv' , ['script.py' , '--analysis-only' , '--output-dir' , str (out_dir )]):
835+ with patch ('quantpits.scripts.brute_force_ensemble.load_predictions' , return_value = (norm_df , {})):
836+ with patch ('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args' , return_value = (norm_df , pd .DataFrame ())):
837+ with patch ('builtins.print' ) as mock_print :
838+ with patch ('quantpits.scripts.brute_force_ensemble.analyze_results' ):
839+ bfe .main ()
840+ assert any ("使用最新结果文件" in str (call ) for call in mock_print .call_args_list )
841+
842+ @patch ('quantpits.scripts.brute_force_ensemble.load_config' )
843+ def test_main_oos_validation_no_data (mock_load_cfg , mock_env , tmp_path ):
844+ bfe , _ = mock_env
845+ out_dir = tmp_path / "oos_no_data"
846+ out_dir .mkdir ()
847+ mock_load_cfg .return_value = ({"anchor_date" : "test" , "experiment_name" : "exp" , "models" : {"m1" : "r1" }}, {"TopK" : 1 })
848+ norm_df = pd .DataFrame ({"m1" : [0.5 ]}, index = pd .MultiIndex .from_product ([pd .to_datetime (["2021-01-01" ]), ["A" ]], names = ["datetime" , "instrument" ]))
849+ with patch ('sys.argv' , ['script.py' , '--auto-test-top' , '1' , '--output-dir' , str (out_dir )]):
850+ with patch ('quantpits.scripts.brute_force_ensemble.load_predictions' , return_value = (norm_df , {})):
851+ with patch ('quantpits.scripts.brute_force_ensemble.split_is_oos_by_args' , return_value = (norm_df , pd .DataFrame ())):
852+ with patch ('quantpits.scripts.brute_force_ensemble.brute_force_backtest' , return_value = pd .DataFrame ({"models" : ["m1" ], "Ann_Excess" : [0.1 ]})):
853+ with patch ('quantpits.scripts.brute_force_ensemble.analyze_results' ):
854+ with patch ('builtins.print' ) as mock_print :
855+ bfe .main ()
856+ assert any ("无法进行 OOS 验证:无 OOS 数据" in str (call ) for call in mock_print .call_args_list )
857+
657858
0 commit comments