Skip to content

Commit d2d5146

Browse files
author
Tonny@Home
committed
fix: Decoupled Downstream Scripts from Training Modes
1 parent: 9d1f947 · commit: d2d5146

File tree

5 files changed: 81 additions (+81) and 44 deletions (-44).

quantpits/scripts/ensemble_fusion.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -491,31 +491,35 @@ def generate_ensemble_signal(norm_df, final_weights, static_weights, is_dynamic)
491491
# ============================================================================
492492
def save_predictions(final_score, anchor_date, experiment_name, method,
493493
model_names, model_metrics, static_weights, is_dynamic,
494-
output_dir, combo_name=None, is_default=False):
494+
output_dir, combo_name=None, is_default=False,
495+
prediction_dir=None):
495496
"""
496497
保存融合预测和配置。
497498
498499
Args:
500+
output_dir: 配置/报告输出目录
499501
combo_name: 组合名称(多组合模式下使用)
500502
is_default: 是否为 default combo(额外保存不带 combo_name 的兼容文件)
503+
prediction_dir: 预测 CSV 输出目录 (默认 output/predictions)
501504
"""
502505
# 保存预测
503-
os.makedirs("output/predictions", exist_ok=True)
506+
pred_dir = prediction_dir or os.path.join("output", "predictions")
507+
os.makedirs(pred_dir, exist_ok=True)
504508
ensemble_df = final_score.to_frame('score')
505509

506510
# 文件命名:带 combo_name 或不带
507511
if combo_name:
508-
pred_file = f"output/predictions/ensemble_{combo_name}_{anchor_date}.csv"
512+
pred_file = os.path.join(pred_dir, f"ensemble_{combo_name}_{anchor_date}.csv")
509513
else:
510-
pred_file = f"output/predictions/ensemble_{anchor_date}.csv"
514+
pred_file = os.path.join(pred_dir, f"ensemble_{anchor_date}.csv")
511515

512516
ensemble_df.to_csv(pred_file)
513517
print(f"\nEnsemble 预测已保存: {pred_file}")
514518
print(f"Total: {len(ensemble_df)} records")
515519

516520
# default combo 额外保存一份兼容文件
517521
if combo_name and is_default:
518-
compat_file = f"output/predictions/ensemble_{anchor_date}.csv"
522+
compat_file = os.path.join(pred_dir, f"ensemble_{anchor_date}.csv")
519523
ensemble_df.to_csv(compat_file)
520524
print(f"Default 兼容文件: {compat_file}")
521525

@@ -1322,10 +1326,12 @@ def run_single_combo(combo_name, selected_models, method, manual_weights_str,
13221326
)
13231327

13241328
# ---- Stage 5: 保存预测 ----
1329+
prediction_dir = getattr(args, 'prediction_dir', None)
13251330
pred_file = save_predictions(
13261331
final_score, anchor_date, experiment_name, method,
13271332
combo_models, combo_metrics, static_weights, is_dynamic,
1328-
combo_output_dir, combo_name=combo_name, is_default=is_default
1333+
combo_output_dir, combo_name=combo_name, is_default=is_default,
1334+
prediction_dir=prediction_dir
13291335
)
13301336

13311337
# ---- Stage 6: 回测 ----
@@ -1418,6 +1424,8 @@ def main():
14181424
help='训练记录文件 (默认 latest_train_records.json)')
14191425
parser.add_argument('--output-dir', type=str, default='output/ensemble',
14201426
help='输出目录 (默认 output/ensemble)')
1427+
parser.add_argument('--prediction-dir', type=str, default=None,
1428+
help='预测 CSV 输出目录 (默认 output/predictions)')
14211429
parser.add_argument('--no-backtest', action='store_true',
14221430
help='跳过回测')
14231431
parser.add_argument('--no-charts', action='store_true',

quantpits/scripts/order_gen.py

Lines changed: 23 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,8 @@ def get_cashflow_today(cashflow_config, anchor_date):
9696
# ============================================================================
9797
# Stage 1: 加载预测数据
9898
# ============================================================================
99-
def load_predictions(prediction_file=None, model_name=None, anchor_date=None):
99+
def load_predictions(prediction_file=None, model_name=None, anchor_date=None,
100+
prediction_dir=None):
100101
"""
101102
加载预测数据。
102103
@@ -120,7 +121,8 @@ def load_predictions(prediction_file=None, model_name=None, anchor_date=None):
120121

121122
if model_name:
122123
# 按模型名搜索
123-
pattern = os.path.join(PREDICTION_DIR, f"{model_name}_*.csv")
124+
_pred_dir = prediction_dir or PREDICTION_DIR
125+
pattern = os.path.join(_pred_dir, f"{model_name}_*.csv")
124126
files = sorted(glob.glob(pattern))
125127
if not files:
126128
raise FileNotFoundError(
@@ -136,24 +138,25 @@ def load_predictions(prediction_file=None, model_name=None, anchor_date=None):
136138
# 优先级: ensemble_YYYY-MM-DD.csv (default combo 的向后兼容副本)
137139
# > ensemble_default_YYYY-MM-DD.csv (显式 default combo)
138140
# > ensemble_*.csv (任意 combo)
141+
_pred_dir = prediction_dir or PREDICTION_DIR
139142
pred_file = None
140143

141144
# 1) 向后兼容格式: ensemble_YYYY-MM-DD.csv (无 combo 名)
142-
compat_pattern = os.path.join(PREDICTION_DIR, "ensemble_[0-9]*.csv")
145+
compat_pattern = os.path.join(_pred_dir, "ensemble_[0-9]*.csv")
143146
compat_files = sorted(glob.glob(compat_pattern))
144147
if compat_files:
145148
pred_file = compat_files[-1]
146149

147150
# 2) 若无,尝试 ensemble_default_YYYY-MM-DD.csv
148151
if not pred_file:
149-
default_pattern = os.path.join(PREDICTION_DIR, "ensemble_default_*.csv")
152+
default_pattern = os.path.join(_pred_dir, "ensemble_default_*.csv")
150153
default_files = sorted(glob.glob(default_pattern))
151154
if default_files:
152155
pred_file = default_files[-1]
153156

154157
# 3) 若仍无,回退到任意 ensemble_*.csv(按日期排序)
155158
if not pred_file:
156-
pattern = os.path.join(PREDICTION_DIR, "ensemble_*.csv")
159+
pattern = os.path.join(_pred_dir, "ensemble_*.csv")
157160
files = sorted(glob.glob(pattern))
158161
if not files:
159162
raise FileNotFoundError(
@@ -295,7 +298,8 @@ def _load_pred_latest_day(pred_source, source_type, valid_instruments=None):
295298
def generate_model_opinions(focus_instruments, current_holding_instruments,
296299
top_k, drop_n, buy_suggestion_factor,
297300
sorted_df, output_dir, next_trade_date_string,
298-
dry_run=False):
301+
dry_run=False, record_file=None,
302+
prediction_dir=None):
299303
"""
300304
加载所有 combo 和单一模型的预测,对每个标的生成判断。
301305
@@ -343,13 +347,13 @@ def generate_model_opinions(focus_instruments, current_holding_instruments,
343347
# 1) Combo 预测
344348
for combo_name, cfg in combos.items():
345349
combo_info[combo_name] = cfg.get('models', [])
346-
pattern = os.path.join(PREDICTION_DIR, f"ensemble_{combo_name}_*.csv")
350+
pattern = os.path.join(prediction_dir or PREDICTION_DIR, f"ensemble_{combo_name}_*.csv")
347351
files = sorted(glob.glob(pattern))
348352
if files:
349353
sources.append((f"combo_{combo_name}", files[-1], 'combo', combo_name))
350354
continue
351355
if cfg.get('default', False):
352-
pattern2 = os.path.join(PREDICTION_DIR, "ensemble_*.csv")
356+
pattern2 = os.path.join(prediction_dir or PREDICTION_DIR, "ensemble_*.csv")
353357
generic_files = []
354358
for f_path in sorted(glob.glob(pattern2)):
355359
basename = os.path.basename(f_path)
@@ -364,13 +368,13 @@ def generate_model_opinions(focus_instruments, current_holding_instruments,
364368
for cfg in combos.values():
365369
all_single_models.update(cfg.get('models', []))
366370
for model_name in sorted(all_single_models):
367-
pattern = os.path.join(PREDICTION_DIR, f"{model_name}_*.csv")
371+
pattern = os.path.join(prediction_dir or PREDICTION_DIR, f"{model_name}_*.csv")
368372
files = sorted(glob.glob(pattern))
369373
if files:
370374
sources.append((f"model_{model_name}", files[-1], 'model', model_name))
371375
else:
372376
try:
373-
train_records_file = os.path.join(ROOT_DIR, 'config', 'latest_train_records.json')
377+
train_records_file = record_file or os.path.join(ROOT_DIR, 'config', 'latest_train_records.json')
374378
if os.path.exists(train_records_file):
375379
with open(train_records_file, 'r') as f:
376380
train_records = json.load(f)
@@ -602,6 +606,11 @@ def main():
602606
help='直接指定预测文件路径')
603607
parser.add_argument('--output-dir', type=str, default='output',
604608
help='输出目录 (默认 output)')
609+
parser.add_argument('--prediction-dir', type=str, default=None,
610+
help='预测文件搜索目录 (默认 output/predictions)')
611+
parser.add_argument('--record-file', type=str, default=None,
612+
help='训练记录文件,用于加载单模型 PKL 预测 '
613+
'(默认 config/latest_train_records.json)')
605614
parser.add_argument('--dry-run', action='store_true',
606615
help='仅打印订单计划,不写入文件')
607616
parser.add_argument('--verbose', action='store_true',
@@ -659,7 +668,8 @@ def main():
659668
pred_df, source_desc = load_predictions(
660669
prediction_file=args.prediction_file,
661670
model_name=args.model,
662-
anchor_date=anchor_date
671+
anchor_date=anchor_date,
672+
prediction_dir=args.prediction_dir
663673
)
664674

665675
print(f"预测来源 : {source_desc}")
@@ -740,7 +750,8 @@ def main():
740750
opinions_df, combo_info = generate_model_opinions(
741751
focus_instruments, current_holding_instruments,
742752
top_k, drop_n, buy_suggestion_factor,
743-
sorted_df, args.output_dir, next_trade_date_string, dry_run=args.dry_run
753+
sorted_df, args.output_dir, next_trade_date_string, dry_run=args.dry_run,
754+
record_file=args.record_file, prediction_dir=args.prediction_dir
744755
)
745756

746757
if opinions_df is not None and not opinions_df.empty and args.verbose:

quantpits/scripts/signal_ranking.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -59,17 +59,21 @@
5959
# ============================================================================
6060
# 配置解析 (复用 ensemble_fusion.py 的逻辑)
6161
# ============================================================================
62-
def parse_ensemble_config():
62+
def parse_ensemble_config(config_file=None):
6363
"""
6464
解析 ensemble_config.json,兼容新旧格式。
6565
66+
Args:
67+
config_file: 配置文件路径 (默认 ENSEMBLE_CONFIG_FILE)
68+
6669
Returns:
6770
combos: dict, combo_name -> {"models": [], "method": str, "default": bool}
6871
"""
69-
if not os.path.exists(ENSEMBLE_CONFIG_FILE):
72+
_config_file = config_file or ENSEMBLE_CONFIG_FILE
73+
if not os.path.exists(_config_file):
7074
return {}
7175

72-
with open(ENSEMBLE_CONFIG_FILE, 'r') as f:
76+
with open(_config_file, 'r') as f:
7377
config = json.load(f)
7478

7579
if 'combos' in config:
@@ -144,22 +148,24 @@ def generate_signal_scores(pred_df, top_n=300):
144148
return output_df, latest_date
145149

146150

147-
def find_prediction_file(combo_name=None, anchor_date=None):
151+
def find_prediction_file(combo_name=None, anchor_date=None, prediction_dir=None):
148152
"""
149153
查找预测文件。
150154
151155
Args:
152156
combo_name: combo 名称,None 表示查找 default ensemble
153157
anchor_date: 日期限制
158+
prediction_dir: 预测文件搜索目录 (默认 PREDICTION_DIR)
154159
155160
Returns:
156161
pred_file: 文件路径
157162
"""
163+
_pred_dir = prediction_dir or PREDICTION_DIR
158164
if combo_name:
159-
pattern = os.path.join(PREDICTION_DIR, f"ensemble_{combo_name}_*.csv")
165+
pattern = os.path.join(_pred_dir, f"ensemble_{combo_name}_*.csv")
160166
else:
161167
# 查找不带 combo name 的通用 ensemble 文件
162-
pattern = os.path.join(PREDICTION_DIR, "ensemble_*.csv")
168+
pattern = os.path.join(_pred_dir, "ensemble_*.csv")
163169

164170
files = sorted(glob.glob(pattern))
165171

@@ -228,6 +234,8 @@ def main():
228234
help='输出 Top N 个标的 (默认 300)')
229235
parser.add_argument('--output-dir', type=str, default='output/ranking',
230236
help='输出目录 (默认 output/ranking)')
237+
parser.add_argument('--prediction-dir', type=str, default=None,
238+
help='预测文件搜索目录 (默认 output/predictions)')
231239
parser.add_argument('--dry-run', action='store_true',
232240
help='仅打印,不写入文件')
233241
args = parser.parse_args()
@@ -259,7 +267,8 @@ def main():
259267
sys.exit(1)
260268
for name, cfg in combos.items():
261269
try:
262-
pred_file = find_prediction_file(combo_name=name)
270+
pred_file = find_prediction_file(combo_name=name,
271+
prediction_dir=args.prediction_dir)
263272
tasks.append((name, pred_file))
264273
except FileNotFoundError as e:
265274
print(f"Warning: {e}")
@@ -270,12 +279,13 @@ def main():
270279
print(f"\n多组合模式: 共 {len(tasks)} 个 combo")
271280

272281
elif args.combo:
273-
pred_file = find_prediction_file(combo_name=args.combo)
282+
pred_file = find_prediction_file(combo_name=args.combo,
283+
prediction_dir=args.prediction_dir)
274284
tasks.append((args.combo, pred_file))
275285

276286
else:
277287
# Default: 使用最新 ensemble 预测
278-
pred_file = find_prediction_file()
288+
pred_file = find_prediction_file(prediction_dir=args.prediction_dir)
279289
tasks.append(('default', pred_file))
280290

281291
# ---- 逐任务处理 ----

tests/quantpits/scripts/test_order_gen.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -384,7 +384,8 @@ def test_generate_model_opinions_from_qlib_recorder(mock_R, mock_env, tmp_path):
384384
with patch('quantpits.scripts.order_gen.ROOT_DIR', str(workspace)):
385385
opinions_df, combo_info = order_gen.generate_model_opinions(
386386
["A"], [], top_k=1, drop_n=0, buy_suggestion_factor=1,
387-
sorted_df=sorted_df, output_dir=str(tmp_path), next_trade_date_string="2020-01-01"
387+
sorted_df=sorted_df, output_dir=str(tmp_path), next_trade_date_string="2020-01-01",
388+
record_file=str(train_records_file)
388389
)
389390

390391
assert "model_gru" in opinions_df.columns
@@ -432,7 +433,8 @@ def test_main_dry_run_full(mock_D, mock_safeguard, mock_price, mock_pred, mock_c
432433
mock_D.calendar.return_value = [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02")]
433434

434435
import sys
435-
with patch.object(sys, 'argv', ['script.py', '--dry-run', '--verbose']):
436+
with patch.object(sys, 'argv', ['script.py', '--dry-run', '--verbose',
437+
'--prediction-dir', str(workspace / "output" / "predictions")]):
436438
order_gen.main()
437439

438440
# Check if a few key print messages were hit (via capsys if we had it, but mostly we want coverage)

0 commit comments

Comments (0)