Skip to content

Commit 3044b6c

Browse files
author
Tonny@Home
committed
feature: Rolling train: enable merge new model for cold start and add back test for combined predictions
1 parent 9226f1c commit 3044b6c

File tree

2 files changed

+434
-13
lines changed

2 files changed

+434
-13
lines changed

quantpits/scripts/rolling_train.py

Lines changed: 224 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,12 @@ def parse_args():
585585
help='仅使用最新模型预测,不训练')
586586
mode.add_argument('--resume', action='store_true',
587587
help='从断点恢复训练')
588+
mode.add_argument('--merge', action='store_true',
589+
help='合并冷启动:在已有 rolling 状态基础上追加新模型')
590+
mode.add_argument('--backtest', action='store_true',
591+
help='训练拼接完成后,对产出的合成预测进行全量回测')
592+
mode.add_argument('--backtest-only', action='store_true',
593+
help='仅对 latest_rolling_records.json 中的模型进行回测 (跳过训练预测)')
588594

589595
select = parser.add_argument_group('模型选择')
590596
select.add_argument('--models', type=str,
@@ -712,10 +718,14 @@ def run_cold_start(args, targets, rolling_cfg):
712718

713719
# 初始化状态
714720
state = RollingState()
715-
if not args.resume:
721+
if getattr(args, 'resume', False) is not True and getattr(args, 'merge', False) is not True:
716722
state.init_run(rolling_cfg, anchor_date, len(windows))
717723
else:
718-
print("⏩ Resume 模式:跳过已完成的 window×model")
724+
if not state.anchor_date:
725+
print("❌ 无 rolling 状态可恢复,将新建状态")
726+
state.init_run(rolling_cfg, anchor_date, len(windows))
727+
else:
728+
print(f"⏩ {'Merge' if getattr(args, 'merge', False) is True else 'Resume'} 模式:跳过已完成窗格")
719729

720730
rolling_exp_name = f"Rolling_Windows_{freq}"
721731
combined_exp_name = f"Rolling_Combined_{freq}"
@@ -753,13 +763,23 @@ def run_cold_start(args, targets, rolling_cfg):
753763
print(f"❌ {model_name} W{widx} 训练失败: {result.get('error', 'Unknown')}")
754764

755765
# 拼接预测
756-
model_names = list(targets.keys())
766+
if getattr(args, 'merge', False) is True:
767+
completed = state.get_all_completed_windows()
768+
all_models = set(targets.keys())
769+
for win, models in completed.items():
770+
all_models.update(models.keys())
771+
model_names = list(all_models)
772+
else:
773+
model_names = list(targets.keys())
774+
757775
combined_records = concatenate_rolling_predictions(
758776
state, model_names, rolling_exp_name, combined_exp_name, anchor_date
759777
)
760778

761779
if combined_records:
762780
save_rolling_records(combined_records, combined_exp_name, anchor_date)
781+
if getattr(args, 'backtest', False) is True:
782+
run_combined_backtest(model_names, combined_records, combined_exp_name, params_base)
763783

764784
# 完成
765785
print(f"\n{'='*60}")
@@ -843,6 +863,8 @@ def run_daily(args, targets, rolling_cfg):
843863
)
844864
if combined_records:
845865
save_rolling_records(combined_records, combined_exp_name, anchor_date)
866+
if getattr(args, 'backtest', False) is True:
867+
run_combined_backtest(model_names, combined_records, combined_exp_name, params_base)
846868

847869
print(f"\n✅ Rolling 滚动更新完成 (新训练 {len(new_windows)} 个 windows)")
848870

@@ -856,6 +878,193 @@ def run_daily(args, targets, rolling_cfg):
856878
)
857879

858880

881+
def run_combined_backtest(model_names, combined_records, combined_exp_name, params_base):
882+
"""
883+
对合并后的预测执行回测,并将回测结果的 port_analysis 等指标保存追加回相应的记录。
884+
"""
885+
from qlib.workflow import R
886+
from qlib.backtest import backtest
887+
from qlib.backtest.executor import SimulatorExecutor
888+
import strategy
889+
import numpy as np
890+
import warnings
891+
892+
print(f"\n{'='*60}")
893+
print("📈 运行 Rolling 合并预测的回测")
894+
print(f"{'='*60}")
895+
896+
st_config = strategy.load_strategy_config()
897+
bt_config = strategy.get_backtest_config(st_config)
898+
899+
for model_name in model_names:
900+
if model_name not in combined_records:
901+
continue
902+
903+
record_id = combined_records[model_name]
904+
print(f"\n [{model_name}] 提取合并预测以进行回测 (Record: {record_id})...")
905+
906+
try:
907+
rec = R.get_recorder(recorder_id=record_id, experiment_name=combined_exp_name)
908+
pred = rec.load_object("pred.pkl")
909+
910+
if pred is None or pred.empty:
911+
print(f" [{model_name}] 预测为空,跳过回测。")
912+
continue
913+
914+
bt_start = str(pred.index.get_level_values(0).min().date())
915+
bt_end = str(pred.index.get_level_values(0).max().date())
916+
917+
print(f" [{model_name}] Backtest Range: {bt_start} ~ {bt_end}")
918+
919+
# Create Strategy
920+
strategy_inst = strategy.create_backtest_strategy(pred, st_config)
921+
922+
# Create Executor
923+
executor_obj = SimulatorExecutor(
924+
time_per_step=params_base['freq'],
925+
generate_portfolio_metrics=True,
926+
verbose=False
927+
)
928+
929+
print(f" [{model_name}] 执行回测...")
930+
with np.errstate(divide='ignore', invalid='ignore'), warnings.catch_warnings():
931+
warnings.simplefilter("ignore", RuntimeWarning)
932+
raw_portfolio_metrics, raw_indicators = backtest(
933+
executor=executor_obj,
934+
strategy=strategy_inst,
935+
start_time=bt_start,
936+
end_time=bt_end,
937+
account=bt_config['account'],
938+
benchmark=params_base['benchmark'],
939+
exchange_kwargs=bt_config['exchange_kwargs']
940+
)
941+
942+
# Use PortfolioAnalyzer to get traditional metrics
943+
from quantpits.scripts.analysis.portfolio_analyzer import PortfolioAnalyzer
944+
import pandas as pd
945+
from qlib.data import D
946+
947+
def extract_report_df(metrics):
948+
if isinstance(metrics, dict):
949+
val = list(metrics.values())[0]
950+
return val[0] if isinstance(val, tuple) else val
951+
elif isinstance(metrics, tuple):
952+
first = metrics[0]
953+
if isinstance(first, pd.DataFrame):
954+
return first
955+
elif isinstance(first, tuple) and len(first) >= 1:
956+
return first[0]
957+
return metrics
958+
return metrics
959+
960+
report_df = extract_report_df(raw_portfolio_metrics)
961+
if report_df is None or report_df.empty:
962+
print(f" [{model_name}] 提取回测结果失败。")
963+
continue
964+
965+
# Format report DataFrame
966+
da_df = pd.DataFrame(index=report_df.index)
967+
da_df['收盘价值'] = report_df['account']
968+
da_df[params_base['benchmark']] = (1 + report_df['bench']).cumprod()
969+
if not isinstance(da_df.index, pd.DatetimeIndex):
970+
da_df.index = pd.to_datetime(da_df.index)
971+
972+
bt_start_dt = da_df.index.min()
973+
bt_end_dt = da_df.index.max()
974+
daily_dates = D.calendar(start_time=bt_start_dt, end_time=bt_end_dt, freq='day')
975+
da_df = da_df.reindex(daily_dates, method='ffill').dropna(subset=['收盘价值'])
976+
da_df = da_df.reset_index().rename(columns={'index': '成交日期', 'datetime': '成交日期'})
977+
978+
pa = PortfolioAnalyzer(
979+
daily_amount_df=da_df,
980+
trade_log_df=pd.DataFrame(),
981+
holding_log_df=pd.DataFrame(),
982+
benchmark_col=params_base['benchmark'],
983+
freq=params_base['freq']
984+
)
985+
metrics = pa.calculate_traditional_metrics()
986+
987+
ann_ret = metrics.get('CAGR', 0)
988+
max_dd = metrics.get('Max_Drawdown', 0)
989+
excess = metrics.get('Excess_Return_CAGR', 0)
990+
ir = metrics.get('Information_Ratio', 0)
991+
calmar = metrics.get('Calmar', 0)
992+
993+
print(f" [{model_name}] 回测完成! Ann_Ret: {ann_ret:.2%}, Excess: {excess:.2%}, Max_DD: {max_dd:.2%}, IR: {ir:.3f}")
994+
995+
# Save objects back to the same recorder
996+
# By calling methods directly on the existing `rec` object
997+
try:
998+
rec.log_metrics(
999+
Ann_Ret=ann_ret,
1000+
Max_DD=max_dd,
1001+
Excess_Return=excess,
1002+
Information_Ratio=ir,
1003+
Calmar=calmar
1004+
)
1005+
1006+
# 严格按照 Qlib PortAnaRecord 的保存格式,把报告分离并保存到 portfolio_analysis 子目录下
1007+
port_ana_objs = {}
1008+
if isinstance(raw_portfolio_metrics, dict):
1009+
for freq_key, metrics_tuple in raw_portfolio_metrics.items():
1010+
if isinstance(metrics_tuple, tuple) and len(metrics_tuple) >= 2:
1011+
port_ana_objs[f"report_normal_{freq_key}.pkl"] = metrics_tuple[0]
1012+
port_ana_objs[f"positions_normal_{freq_key}.pkl"] = metrics_tuple[1]
1013+
elif isinstance(metrics_tuple, tuple) and len(metrics_tuple) == 1:
1014+
port_ana_objs[f"report_normal_{freq_key}.pkl"] = metrics_tuple[0]
1015+
1016+
if port_ana_objs:
1017+
rec.save_objects(artifact_path="portfolio_analysis", **port_ana_objs)
1018+
1019+
# 指标分析保存到 sig_analysis 子目录(依照 SigAnaRecord)或者根目录
1020+
rec.save_objects(artifact_path="sig_analysis", **{
1021+
f"indicator_analysis_{params_base['freq']}.pkl": raw_indicators
1022+
})
1023+
except Exception as log_e:
1024+
print(f" [{model_name}] MLflow 记录失败,可能已存在同名 metric: {log_e}")
1025+
1026+
except Exception as e:
1027+
print(f" [{model_name}] 回测过程失败: {e}")
1028+
import traceback
1029+
traceback.print_exc()
1030+
1031+
1032+
def run_backtest_only(args, targets):
1033+
"""仅回测模式:读取 latest_rolling_records.json 直接运行"""
1034+
import os, json
1035+
from train_utils import ROLLING_RECORD_FILE
1036+
env.init_qlib()
1037+
params_base = get_base_params()
1038+
1039+
if os.path.exists(ROLLING_RECORD_FILE):
1040+
with open(ROLLING_RECORD_FILE, 'r') as f:
1041+
records = json.load(f)
1042+
else:
1043+
records = None
1044+
1045+
if not records or "models" not in records:
1046+
print("❌ 找不到有效的 latest_rolling_records.json 或内容为空。")
1047+
return
1048+
1049+
combined_exp_name = records.get("experiment_name")
1050+
if not combined_exp_name:
1051+
freq = params_base['freq'].upper()
1052+
combined_exp_name = f"Rolling_Combined_{freq}"
1053+
1054+
combined_records = records["models"]
1055+
model_names = []
1056+
1057+
for m in targets.keys():
1058+
if m in combined_records:
1059+
model_names.append(m)
1060+
1061+
if not model_names:
1062+
print("❌ 选定的模型中没有找到历史滚动预测记录。")
1063+
return
1064+
1065+
run_combined_backtest(model_names, combined_records, combined_exp_name, params_base)
1066+
1067+
8591068
def run_predict_only(args, targets, rolling_cfg):
8601069
"""仅预测模式"""
8611070
env.init_qlib()
@@ -909,31 +1118,33 @@ def main():
9091118
args.tag, args.all_enabled
9101119
])
9111120

912-
if not has_selection and not args.resume:
1121+
if getattr(args, 'resume', False) is True or getattr(args, 'merge', False) is True or getattr(args, 'backtest_only', False) is True:
1122+
if not has_selection:
1123+
args.all_enabled = True
1124+
has_selection = True
1125+
1126+
if not has_selection:
9131127
print("❌ 请指定要训练的模型")
9141128
print(" 使用 --models, --algorithm, --dataset, --tag, 或 --all-enabled")
9151129
return
9161130

917-
if args.resume:
918-
# Resume 模式: 从状态中恢复
1131+
if getattr(args, 'resume', False) is True or getattr(args, 'merge', False) is True:
9191132
state = RollingState()
920-
if not state.anchor_date:
1133+
if not state.anchor_date and getattr(args, 'resume', False) is True:
9211134
print("❌ 无 rolling 状态可恢复")
9221135
return
9231136

924-
if not has_selection:
925-
# 没有指定模型,使用 --all-enabled
926-
args.all_enabled = True
927-
9281137
targets = resolve_target_models(args)
9291138
if targets is None or not targets:
9301139
print("⚠️ 没有匹配的模型")
9311140
return
9321141

9331142
# 选择运行模式
934-
if args.predict_only:
1143+
if getattr(args, 'backtest_only', False) is True:
1144+
run_backtest_only(args, targets)
1145+
elif args.predict_only:
9351146
run_predict_only(args, targets, rolling_cfg)
936-
elif args.cold_start or args.resume:
1147+
elif args.cold_start or getattr(args, 'resume', False) is True or getattr(args, 'merge', False) is True:
9371148
run_cold_start(args, targets, rolling_cfg)
9381149
else:
9391150
run_daily(args, targets, rolling_cfg)

0 commit comments

Comments
 (0)