@@ -585,6 +585,12 @@ def parse_args():
585585 help = '仅使用最新模型预测,不训练' )
586586 mode .add_argument ('--resume' , action = 'store_true' ,
587587 help = '从断点恢复训练' )
588+ mode .add_argument ('--merge' , action = 'store_true' ,
589+ help = '合并冷启动:在已有 rolling 状态基础上追加新模型' )
590+ mode .add_argument ('--backtest' , action = 'store_true' ,
591+ help = '训练拼接完成后,对产出的合成预测进行全量回测' )
592+ mode .add_argument ('--backtest-only' , action = 'store_true' ,
593+ help = '仅对 latest_rolling_records.json 中的模型进行回测 (跳过训练预测)' )
588594
589595 select = parser .add_argument_group ('模型选择' )
590596 select .add_argument ('--models' , type = str ,
@@ -712,10 +718,14 @@ def run_cold_start(args, targets, rolling_cfg):
712718
713719 # 初始化状态
714720 state = RollingState ()
715- if not args . resume :
721+ if getattr ( args , 'resume' , False ) is not True and getattr ( args , 'merge' , False ) is not True :
716722 state .init_run (rolling_cfg , anchor_date , len (windows ))
717723 else :
718- print ("⏩ Resume 模式:跳过已完成的 window×model" )
724+ if not state .anchor_date :
725+ print ("❌ 无 rolling 状态可恢复,将新建状态" )
726+ state .init_run (rolling_cfg , anchor_date , len (windows ))
727+ else :
728+ print (f"⏩ { 'Merge' if getattr (args , 'merge' , False ) is True else 'Resume' } 模式:跳过已完成窗格" )
719729
720730 rolling_exp_name = f"Rolling_Windows_{ freq } "
721731 combined_exp_name = f"Rolling_Combined_{ freq } "
@@ -753,13 +763,23 @@ def run_cold_start(args, targets, rolling_cfg):
753763 print (f"❌ { model_name } W{ widx } 训练失败: { result .get ('error' , 'Unknown' )} " )
754764
755765 # 拼接预测
756- model_names = list (targets .keys ())
766+ if getattr (args , 'merge' , False ) is True :
767+ completed = state .get_all_completed_windows ()
768+ all_models = set (targets .keys ())
769+ for win , models in completed .items ():
770+ all_models .update (models .keys ())
771+ model_names = list (all_models )
772+ else :
773+ model_names = list (targets .keys ())
774+
757775 combined_records = concatenate_rolling_predictions (
758776 state , model_names , rolling_exp_name , combined_exp_name , anchor_date
759777 )
760778
761779 if combined_records :
762780 save_rolling_records (combined_records , combined_exp_name , anchor_date )
781+ if getattr (args , 'backtest' , False ) is True :
782+ run_combined_backtest (model_names , combined_records , combined_exp_name , params_base )
763783
764784 # 完成
765785 print (f"\n { '=' * 60 } " )
@@ -843,6 +863,8 @@ def run_daily(args, targets, rolling_cfg):
843863 )
844864 if combined_records :
845865 save_rolling_records (combined_records , combined_exp_name , anchor_date )
866+ if getattr (args , 'backtest' , False ) is True :
867+ run_combined_backtest (model_names , combined_records , combined_exp_name , params_base )
846868
847869 print (f"\n ✅ Rolling 滚动更新完成 (新训练 { len (new_windows )} 个 windows)" )
848870
@@ -856,6 +878,193 @@ def run_daily(args, targets, rolling_cfg):
856878 )
857879
858880
def run_combined_backtest(model_names, combined_records, combined_exp_name, params_base):
    """Backtest the merged (concatenated) rolling predictions and persist results.

    For every model in ``model_names`` that has a recorder id in
    ``combined_records``, load the combined prediction from its Qlib recorder,
    run the project strategy through a SimulatorExecutor backtest over the
    prediction's full date span, compute traditional portfolio metrics via
    PortfolioAnalyzer, and log metrics/artifacts back onto the SAME recorder,
    mirroring Qlib's PortAnaRecord / SigAnaRecord artifact layout.

    Args:
        model_names: iterable of model names to back-test.
        combined_records: mapping of model name -> recorder id inside the
            combined experiment.
        combined_exp_name: Qlib experiment name holding the combined recorders.
        params_base: base parameter dict; ``freq`` and ``benchmark`` are read.

    Returns:
        None. Failures for one model are printed and do not abort the batch.
    """
    # Local imports keep heavy qlib/strategy modules off the module import path.
    from qlib.workflow import R
    from qlib.backtest import backtest
    from qlib.backtest.executor import SimulatorExecutor
    import strategy
    import numpy as np
    import warnings

    print(f"\n{'=' * 60}")
    print("📈 运行 Rolling 合并预测的回测")
    print(f"{'=' * 60}")

    st_config = strategy.load_strategy_config()
    bt_config = strategy.get_backtest_config(st_config)

    for model_name in model_names:
        if model_name not in combined_records:
            continue

        record_id = combined_records[model_name]
        print(f"\n[{model_name}] 提取合并预测以进行回测 (Record: {record_id})...")

        try:
            rec = R.get_recorder(recorder_id=record_id, experiment_name=combined_exp_name)
            pred = rec.load_object("pred.pkl")

            if pred is None or pred.empty:
                print(f"  [{model_name}] 预测为空,跳过回测。")
                continue

            # Backtest window = full date span of the concatenated prediction.
            # NOTE(review): assumes the first index level is the datetime level
            # (standard qlib pred.pkl layout) — confirm against producer.
            bt_start = str(pred.index.get_level_values(0).min().date())
            bt_end = str(pred.index.get_level_values(0).max().date())

            print(f"  [{model_name}] Backtest Range: {bt_start} ~ {bt_end}")

            # Create strategy instance and simulator executor.
            strategy_inst = strategy.create_backtest_strategy(pred, st_config)
            executor_obj = SimulatorExecutor(
                time_per_step=params_base['freq'],
                generate_portfolio_metrics=True,
                verbose=False
            )

            print(f"  [{model_name}] 执行回测...")
            # Silence divide-by-zero / invalid-value noise from the accounting math.
            with np.errstate(divide='ignore', invalid='ignore'), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                raw_portfolio_metrics, raw_indicators = backtest(
                    executor=executor_obj,
                    strategy=strategy_inst,
                    start_time=bt_start,
                    end_time=bt_end,
                    account=bt_config['account'],
                    benchmark=params_base['benchmark'],
                    exchange_kwargs=bt_config['exchange_kwargs']
                )

            # Use PortfolioAnalyzer to compute traditional metrics.
            from quantpits.scripts.analysis.portfolio_analyzer import PortfolioAnalyzer
            import pandas as pd
            from qlib.data import D

            def extract_report_df(metrics):
                """Unwrap qlib's version-dependent backtest return shapes to a report DataFrame."""
                if isinstance(metrics, dict):
                    # FIX: guard empty dict — list(metrics.values())[0] raised IndexError.
                    if not metrics:
                        return None
                    val = next(iter(metrics.values()))
                    return val[0] if isinstance(val, tuple) else val
                if isinstance(metrics, tuple):
                    # FIX: guard empty tuple before indexing.
                    if not metrics:
                        return None
                    first = metrics[0]
                    if isinstance(first, pd.DataFrame):
                        return first
                    if isinstance(first, tuple) and len(first) >= 1:
                        return first[0]
                    return metrics
                return metrics

            report_df = extract_report_df(raw_portfolio_metrics)
            if report_df is None or report_df.empty:
                print(f"  [{model_name}] 提取回测结果失败。")
                continue

            # Build the daily account-value frame PortfolioAnalyzer expects:
            # account value plus the benchmark's cumulative return curve.
            da_df = pd.DataFrame(index=report_df.index)
            da_df['收盘价值'] = report_df['account']
            da_df[params_base['benchmark']] = (1 + report_df['bench']).cumprod()
            if not isinstance(da_df.index, pd.DatetimeIndex):
                da_df.index = pd.to_datetime(da_df.index)

            # Align onto the daily trading calendar, forward-filling gaps
            # (e.g. weekly-frequency reports expanded to daily points).
            bt_start_dt = da_df.index.min()
            bt_end_dt = da_df.index.max()
            daily_dates = D.calendar(start_time=bt_start_dt, end_time=bt_end_dt, freq='day')
            da_df = da_df.reindex(daily_dates, method='ffill').dropna(subset=['收盘价值'])
            # Rename whichever index column materialized ('index' or 'datetime').
            da_df = da_df.reset_index().rename(columns={'index': '成交日期', 'datetime': '成交日期'})

            pa = PortfolioAnalyzer(
                daily_amount_df=da_df,
                trade_log_df=pd.DataFrame(),
                holding_log_df=pd.DataFrame(),
                benchmark_col=params_base['benchmark'],
                freq=params_base['freq']
            )
            metrics = pa.calculate_traditional_metrics()

            ann_ret = metrics.get('CAGR', 0)
            max_dd = metrics.get('Max_Drawdown', 0)
            excess = metrics.get('Excess_Return_CAGR', 0)
            ir = metrics.get('Information_Ratio', 0)
            calmar = metrics.get('Calmar', 0)

            print(f"  [{model_name}] 回测完成! Ann_Ret: {ann_ret:.2%}, Excess: {excess:.2%}, Max_DD: {max_dd:.2%}, IR: {ir:.3f}")

            # Persist results onto the existing recorder so the combined
            # experiment carries its own backtest artifacts.
            try:
                rec.log_metrics(
                    Ann_Ret=ann_ret,
                    Max_DD=max_dd,
                    Excess_Return=excess,
                    Information_Ratio=ir,
                    Calmar=calmar
                )

                # Mirror qlib PortAnaRecord: one report/positions pair per
                # frequency, saved under the portfolio_analysis artifact path.
                port_ana_objs = {}
                if isinstance(raw_portfolio_metrics, dict):
                    for freq_key, metrics_tuple in raw_portfolio_metrics.items():
                        if isinstance(metrics_tuple, tuple) and len(metrics_tuple) >= 2:
                            port_ana_objs[f"report_normal_{freq_key}.pkl"] = metrics_tuple[0]
                            port_ana_objs[f"positions_normal_{freq_key}.pkl"] = metrics_tuple[1]
                        elif isinstance(metrics_tuple, tuple) and len(metrics_tuple) == 1:
                            port_ana_objs[f"report_normal_{freq_key}.pkl"] = metrics_tuple[0]

                if port_ana_objs:
                    rec.save_objects(artifact_path="portfolio_analysis", **port_ana_objs)

                # Mirror SigAnaRecord: indicator analysis under sig_analysis.
                rec.save_objects(artifact_path="sig_analysis", **{
                    f"indicator_analysis_{params_base['freq']}.pkl": raw_indicators
                })
            except Exception as log_e:
                print(f"  [{model_name}] MLflow 记录失败,可能已存在同名 metric: {log_e}")

        except Exception as e:
            # Best-effort batch: report the failure and continue with the next model.
            print(f"  [{model_name}] 回测过程失败: {e}")
            import traceback
            traceback.print_exc()
1030+
1031+
def run_backtest_only(args, targets):
    """Backtest-only mode: re-run backtests from latest_rolling_records.json.

    Skips training and prediction entirely. Loads the persisted
    model -> recorder-id mapping, keeps only the user-selected models that
    actually have a saved record, and delegates to run_combined_backtest.

    Args:
        args: parsed CLI namespace (unused here; kept for a uniform mode API).
        targets: mapping of selected model name -> model config.
    """
    import os
    import json
    from train_utils import ROLLING_RECORD_FILE

    env.init_qlib()
    params_base = get_base_params()

    records = None
    if os.path.exists(ROLLING_RECORD_FILE):
        # FIX: read as UTF-8 explicitly — the default locale encoding can fail
        # on non-ASCII content (e.g. Chinese experiment names) on some platforms.
        with open(ROLLING_RECORD_FILE, 'r', encoding='utf-8') as f:
            records = json.load(f)

    if not records or "models" not in records:
        print("❌ 找不到有效的 latest_rolling_records.json 或内容为空。")
        return

    # Fall back to the conventional experiment name if the file predates
    # the experiment_name field.
    combined_exp_name = records.get("experiment_name")
    if not combined_exp_name:
        freq = params_base['freq'].upper()
        combined_exp_name = f"Rolling_Combined_{freq}"

    combined_records = records["models"]
    # Intersect the user's model selection with models that have saved records.
    model_names = [m for m in targets.keys() if m in combined_records]

    if not model_names:
        print("❌ 选定的模型中没有找到历史滚动预测记录。")
        return

    run_combined_backtest(model_names, combined_records, combined_exp_name, params_base)
1066+
1067+
8591068def run_predict_only (args , targets , rolling_cfg ):
8601069 """仅预测模式"""
8611070 env .init_qlib ()
@@ -909,31 +1118,33 @@ def main():
9091118 args .tag , args .all_enabled
9101119 ])
9111120
912- if not has_selection and not args .resume :
1121+ if getattr (args , 'resume' , False ) is True or getattr (args , 'merge' , False ) is True or getattr (args , 'backtest_only' , False ) is True :
1122+ if not has_selection :
1123+ args .all_enabled = True
1124+ has_selection = True
1125+
1126+ if not has_selection :
9131127 print ("❌ 请指定要训练的模型" )
9141128 print (" 使用 --models, --algorithm, --dataset, --tag, 或 --all-enabled" )
9151129 return
9161130
917- if args .resume :
918- # Resume 模式: 从状态中恢复
1131+ if getattr (args , 'resume' , False ) is True or getattr (args , 'merge' , False ) is True :
9191132 state = RollingState ()
920- if not state .anchor_date :
1133+ if not state .anchor_date and getattr ( args , 'resume' , False ) is True :
9211134 print ("❌ 无 rolling 状态可恢复" )
9221135 return
9231136
924- if not has_selection :
925- # 没有指定模型,使用 --all-enabled
926- args .all_enabled = True
927-
9281137 targets = resolve_target_models (args )
9291138 if targets is None or not targets :
9301139 print ("⚠️ 没有匹配的模型" )
9311140 return
9321141
9331142 # 选择运行模式
934- if args .predict_only :
1143+ if getattr (args , 'backtest_only' , False ) is True :
1144+ run_backtest_only (args , targets )
1145+ elif args .predict_only :
9351146 run_predict_only (args , targets , rolling_cfg )
936- elif args .cold_start or args . resume :
1147+ elif args .cold_start or getattr ( args , ' resume' , False ) is True or getattr ( args , 'merge' , False ) is True :
9371148 run_cold_start (args , targets , rolling_cfg )
9381149 else :
9391150 run_daily (args , targets , rolling_cfg )
0 commit comments