@@ -96,7 +96,8 @@ def get_cashflow_today(cashflow_config, anchor_date):
9696# ============================================================================
9797# Stage 1: 加载预测数据
9898# ============================================================================
99- def load_predictions (prediction_file = None , model_name = None , anchor_date = None ):
99+ def load_predictions (prediction_file = None , model_name = None , anchor_date = None ,
100+ prediction_dir = None ):
100101 """
101102 加载预测数据。
102103
@@ -120,7 +121,8 @@ def load_predictions(prediction_file=None, model_name=None, anchor_date=None):
120121
121122 if model_name :
122123 # 按模型名搜索
123- pattern = os .path .join (PREDICTION_DIR , f"{ model_name } _*.csv" )
124+ _pred_dir = prediction_dir or PREDICTION_DIR
125+ pattern = os .path .join (_pred_dir , f"{ model_name } _*.csv" )
124126 files = sorted (glob .glob (pattern ))
125127 if not files :
126128 raise FileNotFoundError (
@@ -136,24 +138,25 @@ def load_predictions(prediction_file=None, model_name=None, anchor_date=None):
136138 # 优先级: ensemble_YYYY-MM-DD.csv (default combo 的向后兼容副本)
137139 # > ensemble_default_YYYY-MM-DD.csv (显式 default combo)
138140 # > ensemble_*.csv (任意 combo)
141+ _pred_dir = prediction_dir or PREDICTION_DIR
139142 pred_file = None
140143
141144 # 1) 向后兼容格式: ensemble_YYYY-MM-DD.csv (无 combo 名)
142- compat_pattern = os .path .join (PREDICTION_DIR , "ensemble_[0-9]*.csv" )
145+ compat_pattern = os .path .join (_pred_dir , "ensemble_[0-9]*.csv" )
143146 compat_files = sorted (glob .glob (compat_pattern ))
144147 if compat_files :
145148 pred_file = compat_files [- 1 ]
146149
147150 # 2) 若无,尝试 ensemble_default_YYYY-MM-DD.csv
148151 if not pred_file :
149- default_pattern = os .path .join (PREDICTION_DIR , "ensemble_default_*.csv" )
152+ default_pattern = os .path .join (_pred_dir , "ensemble_default_*.csv" )
150153 default_files = sorted (glob .glob (default_pattern ))
151154 if default_files :
152155 pred_file = default_files [- 1 ]
153156
154157 # 3) 若仍无,回退到任意 ensemble_*.csv(按日期排序)
155158 if not pred_file :
156- pattern = os .path .join (PREDICTION_DIR , "ensemble_*.csv" )
159+ pattern = os .path .join (_pred_dir , "ensemble_*.csv" )
157160 files = sorted (glob .glob (pattern ))
158161 if not files :
159162 raise FileNotFoundError (
@@ -295,7 +298,8 @@ def _load_pred_latest_day(pred_source, source_type, valid_instruments=None):
295298def generate_model_opinions (focus_instruments , current_holding_instruments ,
296299 top_k , drop_n , buy_suggestion_factor ,
297300 sorted_df , output_dir , next_trade_date_string ,
298- dry_run = False ):
301+ dry_run = False , record_file = None ,
302+ prediction_dir = None ):
299303 """
300304 加载所有 combo 和单一模型的预测,对每个标的生成判断。
301305
@@ -343,13 +347,13 @@ def generate_model_opinions(focus_instruments, current_holding_instruments,
343347 # 1) Combo 预测
344348 for combo_name , cfg in combos .items ():
345349 combo_info [combo_name ] = cfg .get ('models' , [])
346- pattern = os .path .join (PREDICTION_DIR , f"ensemble_{ combo_name } _*.csv" )
350+ pattern = os .path .join (prediction_dir or PREDICTION_DIR , f"ensemble_{ combo_name } _*.csv" )
347351 files = sorted (glob .glob (pattern ))
348352 if files :
349353 sources .append ((f"combo_{ combo_name } " , files [- 1 ], 'combo' , combo_name ))
350354 continue
351355 if cfg .get ('default' , False ):
352- pattern2 = os .path .join (PREDICTION_DIR , "ensemble_*.csv" )
356+ pattern2 = os .path .join (prediction_dir or PREDICTION_DIR , "ensemble_*.csv" )
353357 generic_files = []
354358 for f_path in sorted (glob .glob (pattern2 )):
355359 basename = os .path .basename (f_path )
@@ -364,13 +368,13 @@ def generate_model_opinions(focus_instruments, current_holding_instruments,
364368 for cfg in combos .values ():
365369 all_single_models .update (cfg .get ('models' , []))
366370 for model_name in sorted (all_single_models ):
367- pattern = os .path .join (PREDICTION_DIR , f"{ model_name } _*.csv" )
371+ pattern = os .path .join (prediction_dir or PREDICTION_DIR , f"{ model_name } _*.csv" )
368372 files = sorted (glob .glob (pattern ))
369373 if files :
370374 sources .append ((f"model_{ model_name } " , files [- 1 ], 'model' , model_name ))
371375 else :
372376 try :
373- train_records_file = os .path .join (ROOT_DIR , 'config' , 'latest_train_records.json' )
377+ train_records_file = record_file or os .path .join (ROOT_DIR , 'config' , 'latest_train_records.json' )
374378 if os .path .exists (train_records_file ):
375379 with open (train_records_file , 'r' ) as f :
376380 train_records = json .load (f )
@@ -602,6 +606,11 @@ def main():
602606 help = '直接指定预测文件路径' )
603607 parser .add_argument ('--output-dir' , type = str , default = 'output' ,
604608 help = '输出目录 (默认 output)' )
609+ parser .add_argument ('--prediction-dir' , type = str , default = None ,
610+ help = '预测文件搜索目录 (默认 output/predictions)' )
611+ parser .add_argument ('--record-file' , type = str , default = None ,
612+ help = '训练记录文件,用于加载单模型 PKL 预测 '
613+ '(默认 config/latest_train_records.json)' )
605614 parser .add_argument ('--dry-run' , action = 'store_true' ,
606615 help = '仅打印订单计划,不写入文件' )
607616 parser .add_argument ('--verbose' , action = 'store_true' ,
@@ -659,7 +668,8 @@ def main():
659668 pred_df , source_desc = load_predictions (
660669 prediction_file = args .prediction_file ,
661670 model_name = args .model ,
662- anchor_date = anchor_date
671+ anchor_date = anchor_date ,
672+ prediction_dir = args .prediction_dir
663673 )
664674
665675 print (f"预测来源 : { source_desc } " )
@@ -740,7 +750,8 @@ def main():
740750 opinions_df , combo_info = generate_model_opinions (
741751 focus_instruments , current_holding_instruments ,
742752 top_k , drop_n , buy_suggestion_factor ,
743- sorted_df , args .output_dir , next_trade_date_string , dry_run = args .dry_run
753+ sorted_df , args .output_dir , next_trade_date_string , dry_run = args .dry_run ,
754+ record_file = args .record_file , prediction_dir = args .prediction_dir
744755 )
745756
746757 if opinions_df is not None and not opinions_df .empty and args .verbose :
0 commit comments