Commit d5efe25 (parent d0a07aa)
Author: Tonny@Home

Refactor pre-training workflow to decouple base models, introduce a standalone pre-training script, and implement dynamic path injection with feature consistency validation.

File tree: 10 files changed, +1127 -38 lines


docs/01_TRAINING_GUIDE.md

Lines changed: 55 additions & 10 deletions
@@ -9,6 +9,7 @@
 | `prod_train_predict.py` | Full training + predict | ✅ | configs | `latest_train_records.json` |
 | `incremental_train.py` | Incremental training + predict | ✅ | configs | `latest_train_records.json` |
 | `prod_predict_only.py` | Prediction only | ❌ | existing models | `latest_train_records.json` |
+| `pretrain.py` | Base model pre-training | ✅ | configs | `data/pretrained/` (state_dict) |
 
 Each script automatically backs up the history to `data/history/` before modifying `latest_train_records.json`
 
@@ -23,6 +24,7 @@ QuantPits/
 │   │   ├── prod_train_predict.py    # Full training script
 │   │   ├── incremental_train.py     # Incremental training script
 │   │   ├── prod_predict_only.py     # Prediction-only script (no training)
+│   │   ├── pretrain.py              # 🧠 Base model pre-training script
 │   │   ├── check_workflow_yaml.py   # 🔧 YAML production-config parameter validation
 │   │   └── train_utils.py           # Shared utility module
 │   └── docs/
@@ -39,6 +41,7 @@ QuantPits/
 │   └── model_performance_*.json     # Model performance metrics
 ├── data/
 │   ├── history/                     # 📦 Auto-backed-up historical files
+│   ├── pretrained/                  # 🧠 Pre-trained base models (.pkl + .json)
 │   └── run_state.json               # Incremental-training run state
 └── latest_train_records.json        # Current training records
 ```
@@ -59,10 +62,15 @@ models:
     market: csi300                              # Target market (metadata tag used for CLI filtering)
     yaml_file: config/workflow_config_gru.yaml  # Qlib workflow config
     enabled: true                               # Whether to participate in full training
-    tags: [ts]                                  # Classification tags (for filtering)
+    tags: [basemodel, ts]                       # Classification tags (for filtering)
+    pretrain_source: lstm_Alpha158              # (Optional) declares the base model this model depends on
     notes: "Optional notes"                     # Notes
 ```
 
+#### Key fields:
+- **`tags: [basemodel]`**: marks the model as usable as a pre-training base model.
+- **`pretrain_source`**: declares which base model this upper-layer model depends on; the system automatically looks for the corresponding `_latest.pkl`.
+
 > [!NOTE]
 > **Market configuration distinction**: the `market` field in the registry is a **model metadata tag**, used only for filtering via the `--market` flag during incremental training or prediction. When price/volume data is actually pulled, the system follows the global `market` setting in `model_config.json`.
 
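The `pretrain_source` lookup described above can be sketched as follows. `data/pretrained/` and the `_latest.pkl` naming come from this guide; the helper name `resolve_pretrain_path` is hypothetical, not the repo's actual API:

```python
from pathlib import Path
from typing import Optional

# Directory documented in the file tree above; adjust to your checkout layout.
PRETRAINED_DIR = Path("data/pretrained")

def resolve_pretrain_path(registry_entry: dict) -> Optional[Path]:
    """Look up the base-model weights declared via `pretrain_source`.

    Hypothetical helper: the real scripts may implement this differently.
    """
    source = registry_entry.get("pretrain_source")
    if not source:
        return None  # model declares no base-model dependency
    candidate = PRETRAINED_DIR / f"{source}_latest.pkl"
    if not candidate.exists():
        raise FileNotFoundError(
            f"pretrain_source '{source}' declared but {candidate} is missing; "
            f"run: python quantpits/scripts/pretrain.py --models {source}"
        )
    return candidate

print(resolve_pretrain_path({"tags": ["ts"]}))  # → None (no dependency declared)
```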
@@ -308,17 +316,54 @@ python quantpits/scripts/check_workflow_yaml.py --fix
 
 ---
 
+---
+
+## Base Model Pre-training (`pretrain.py`)
+
+Some complex models (e.g., GATs, ADD, IGMTF) require a pre-trained base model (e.g., LSTM or GRU) for weight initialization.
+
+### Usage Scenarios
+- Providing initialization weights for upper-layer models.
+- Re-training a compatible base model after the feature set (`d_feat`) has changed.
+
+### Core Semantics
+- **Pre-training is not logged in the training records**: it does not modify `latest_train_records.json`
+- **Metadata validation**: every pre-trained file ships with a `.json` metadata file. If an upper-layer model's `d_feat` does not match the pre-trained file, the system raises an error and blocks training.
+
+### Common Commands
+
+```bash
+# 1. List pre-trainable models and their dependencies
+python quantpits/scripts/pretrain.py --list
+
+# 2. Pre-train a specific base model
+python quantpits/scripts/pretrain.py --models lstm_Alpha158
+
+# 3. Pre-train FOR a specific upper-layer model (most recommended: auto-aligns the Dataset config)
+# Even if the feature set was modified, this keeps the base model fully compatible with the upper model
+python quantpits/scripts/pretrain.py --for gats_Alpha158_plus
+
+# 4. Show existing pre-trained files
+python quantpits/scripts/pretrain.py --show-pretrained
+
+# 5. Force random weights (skip pre-training)
+# Available in both incremental_train and prod_predict_only
+python quantpits/scripts/incremental_train.py --models gats_Alpha158_plus --no-pretrain
+```
+
+---
+
 ## About LSTM and GATs
 
-- Training the `lstm_Alpha158` model automatically outputs `csi300_lstm_ts_latest.pkl`
-- That pkl file is the `basemodel` for the GATs model
-- The GATs model config references this file
-- LSTM and GATs are currently both set to `enabled: false`
-- To use GATs, train the LSTM first
-```bash
-python quantpits/scripts/incremental_train.py --models lstm_Alpha158
-python quantpits/scripts/incremental_train.py --models gats_Alpha158_plus
-```
+- `gats_Alpha158_plus` depends on `lstm_Alpha158` by default
+- Full training workflow:
+  1. Pre-train the base model (optional; skipped if one already exists):
+     `python quantpits/scripts/pretrain.py --for gats_Alpha158_plus`
+  2. Train the upper-layer model:
+     `python quantpits/scripts/incremental_train.py --models gats_Alpha158_plus`
+
+- To skip the pre-trained model and use random weights, just add the `--no-pretrain` flag.
+
 
 ---
 
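The metadata check described under "Core Semantics" can be sketched like this. The `.json` sidecar and the `d_feat` field come from the guide; the function name and exact sidecar schema are assumptions:

```python
import json
import tempfile
from pathlib import Path

def validate_feature_consistency(upper_d_feat: int, meta_path: Path) -> None:
    """Raise if the upper model's d_feat differs from the d_feat recorded in
    the pre-trained file's .json sidecar (hypothetical sketch of the check)."""
    meta = json.loads(meta_path.read_text(encoding="utf-8"))
    base_d_feat = meta["d_feat"]
    if base_d_feat != upper_d_feat:
        raise ValueError(
            f"d_feat mismatch: upper model expects {upper_d_feat}, but the "
            f"pre-trained base model was trained with {base_d_feat}; "
            "re-run pretrain.py --for <upper_model> to regenerate it"
        )

# Usage: a matching sidecar passes silently, a mismatch raises.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fh:
    json.dump({"model": "lstm_Alpha158", "d_feat": 158}, fh)
validate_feature_consistency(158, Path(fh.name))
```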
docs/en/01_TRAINING_GUIDE.md

Lines changed: 60 additions & 15 deletions
@@ -4,11 +4,12 @@
 
 The training system consists of four main scripts that share the same utility modules and model registry:
 
-| Script | Purpose | Save Semantics |
-|------|------|----------|
-| `prod_train_predict.py` | Full training of all enabled models | **Full Overwrite** of `latest_train_records.json` |
-| `incremental_train.py` | Selective training of individual models | **Incremental Merge** to `latest_train_records.json` |
-| `prod_predict_only.py` | Prediction only (no training) | **Incremental Merge** to `latest_train_records.json` |
+| Script | Purpose | Training | Data Source | Save Semantics |
+|------|------|------|--------|----------|
+| `prod_train_predict.py` | Full training + predict | ✅ | configs | `latest_train_records.json` |
+| `incremental_train.py` | Incremental training + predict | ✅ | configs | `latest_train_records.json` |
+| `prod_predict_only.py` | Prediction only | ❌ | Existing models | `latest_train_records.json` |
+| `pretrain.py` | Base model pre-training | ✅ | configs | `data/pretrained/` (state_dict) |
 
 These scripts automatically back up the history to `data/history/` before modifying `latest_train_records.json`.
 
@@ -23,6 +24,7 @@ QuantPits/
 │   │   ├── prod_train_predict.py    # Full training script
 │   │   ├── incremental_train.py     # Incremental training script
 │   │   ├── prod_predict_only.py     # Prediction-only script (no training)
+│   │   ├── pretrain.py              # 🧠 Base model pre-training script
 │   │   ├── check_workflow_yaml.py   # 🔧 YAML config production validation & fix
 │   │   └── train_utils.py           # Shared utility module
 │   └── docs/
@@ -39,6 +41,7 @@ QuantPits/
 │   └── model_performance_*.json     # Model performance metrics (IC/ICIR)
 ├── data/
 │   ├── history/                     # 📦 Auto-backed-up historical files
+│   ├── pretrained/                  # 🧠 Pre-trained base models (.pkl + .json)
 │   └── run_state.json               # State tracker for incremental training
 └── latest_train_records.json        # Current training records
 ```
@@ -59,10 +62,15 @@ models:
     market: csi300                              # Target market (metadata tag used for CLI filtering)
     yaml_file: config/workflow_config_gru.yaml  # Qlib workflow config
     enabled: true                               # Whether to participate in full training
-    tags: [ts]                                  # Classification tags (for filtering)
+    tags: [basemodel, ts]                       # Classification tags (for filtering)
+    pretrain_source: lstm_Alpha158              # (Optional) Declare a dependency on a base model
     notes: "Optional notes"                     # Notes
 ```
 
+#### Key Fields:
+- **`tags: [basemodel]`**: Marks the model as a pre-trainable base model.
+- **`pretrain_source`**: Tells the system which base model this upper-layer model depends on; the system automatically looks for the corresponding `_latest.pkl`.
+
 > [!NOTE]
 > **Distinction of Market Configurations**: The `market` field in the registry is strictly a **model metadata tag**, used for CLI filtering via `--market` during incremental training or prediction. Actual data extraction follows the global `market` setting in `model_config.json`.

@@ -308,17 +316,54 @@ python quantpits/scripts/check_workflow_yaml.py --fix
 
 ---
 
+---
+
+## Base Model Pre-training (`pretrain.py`)
+
+Complex models (e.g., GATs, ADD, IGMTF) require a pre-trained base model (e.g., LSTM or GRU) for weight initialization.
+
+### Usage Scenarios
+- Providing initialization weights for upper-layer models.
+- When features (d_feat) are modified, requiring new compatible base models.
+
+### Core Semantics
+- **Pre-training is not logged in records**: It does not modify `latest_train_records.json`.
+- **Metadata Validation**: Each pre-trained file comes with a `.json` metadata file. If an upper model's `d_feat` doesn't match the pre-trained file, training is blocked.
+
+### Common Commands
+
+```bash
+# 1. List pre-trainable models and dependencies
+python quantpits/scripts/pretrain.py --list
+
+# 2. Pre-train a specific base model
+python quantpits/scripts/pretrain.py --models lstm_Alpha158
+
+# 3. Pre-train FOR a specific upper model (Recommended: aligns dataset config)
+# This ensures compatibility even if features are modified.
+python quantpits/scripts/pretrain.py --for gats_Alpha158_plus
+
+# 4. Show existing pre-trained files
+python quantpits/scripts/pretrain.py --show-pretrained
+
+# 5. Force random weights (skip pre-training)
+# Available in both incremental_train and prod_predict_only
+python quantpits/scripts/incremental_train.py --models gats_Alpha158_plus --no-pretrain
+```
+
+---
+
 ## Concerning LSTM and GATs
 
-- The `lstm_Alpha158` model automatically outputs `csi300_lstm_ts_latest.pkl` upon training.
-- This `.pkl` is a required `basemodel` for GATs.
-- GATs configurations implicitly reference this file.
-- Both LSTM and GATs are presently defaulted to `enabled: false`.
-- If GATs is desired, the LSTM must be trained chronologically prior:
-```bash
-python quantpits/scripts/incremental_train.py --models lstm_Alpha158
-python quantpits/scripts/incremental_train.py --models gats_Alpha158_plus
-```
+- `gats_Alpha158_plus` depends on `lstm_Alpha158` by default.
+- Full Workflow:
+  1. Pre-train the base model (optional if one already exists):
+     `python quantpits/scripts/pretrain.py --for gats_Alpha158_plus`
+  2. Train the upper model:
+     `python quantpits/scripts/incremental_train.py --models gats_Alpha158_plus`
+
+- To bypass pre-training and use random weights, pass the `--no-pretrain` flag.
+
 
 ---
 
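The guide says `pretrain.py` stores a state_dict `.pkl` plus a `.json` metadata sidecar under `data/pretrained/`. A minimal sketch of that save step, assuming plain `pickle` for the state_dict; the helper name `save_pretrained` and the exact metadata fields are our own invention:

```python
import json
import pickle
import tempfile
from pathlib import Path

def save_pretrained(name: str, state_dict: dict, d_feat: int, out_dir: Path) -> Path:
    """Write <name>_latest.pkl (state_dict) plus a .json metadata sidecar,
    mirroring the data/pretrained/ layout described above (sketch only; the
    real pretrain.py may record more metadata and use another serializer)."""
    out_dir.mkdir(parents=True, exist_ok=True)
    pkl_path = out_dir / f"{name}_latest.pkl"
    with pkl_path.open("wb") as out:
        pickle.dump(state_dict, out)
    meta = {"model": name, "d_feat": d_feat}
    (out_dir / f"{name}_latest.json").write_text(json.dumps(meta, indent=2))
    return pkl_path

# Usage: persist a toy state_dict into a temporary directory
out_dir = Path(tempfile.mkdtemp())
pkl = save_pretrained("lstm_Alpha158", {"rnn.weight_ih_l0": [0.0]}, 158, out_dir)
print(pkl.name)  # → lstm_Alpha158_latest.pkl
```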
quantpits/scripts/incremental_train.py

Lines changed: 6 additions & 1 deletion
@@ -100,6 +100,8 @@ def parse_args():
                       help='Only print the list of models to train; do not actually train')
     ctrl.add_argument('--experiment-name', type=str, default=None,
                       help='MLflow experiment name (default: Prod_Train_{FREQ})')
+    ctrl.add_argument('--no-pretrain', action='store_true',
+                      help='Ignore pretrain_source and initialize the basemodel with random weights')
 
     # Info display
     info = parser.add_argument_group('Info display')
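The `store_true` flag added above defaults to `False` and flips to `True` only when the flag is passed; argparse also maps the hyphenated `--no-pretrain` to the attribute `no_pretrain`. A standalone reproduction:

```python
import argparse

parser = argparse.ArgumentParser()
ctrl = parser.add_argument_group("control")
ctrl.add_argument("--no-pretrain", action="store_true",
                  help="Ignore pretrain_source and use randomly initialized weights")

# `--no-pretrain` becomes the attribute `no_pretrain`
print(parser.parse_args([]).no_pretrain)                 # → False
print(parser.parse_args(["--no-pretrain"]).no_pretrain)  # → True
```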
@@ -260,7 +262,10 @@ def run_incremental_train(args):
 
     yaml_file = model_info['yaml_file']
 
-    result = train_single_model(model_name, yaml_file, params, experiment_name)
+    result = train_single_model(
+        model_name, yaml_file, params, experiment_name,
+        no_pretrain=args.no_pretrain
+    )
 
     if result['success']:
         new_records['models'][model_name] = result['record_id']
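Inside `train_single_model`, the `no_pretrain` flag plausibly feeds the "dynamic path injection" named in the commit message: patching the parsed workflow YAML with (or without) the base-model path before training. A sketch under stated assumptions; the `model_path` kwarg and function name are guesses, not the repo's confirmed internals:

```python
import copy

def inject_pretrain_path(workflow_cfg: dict, pkl_path: str, no_pretrain: bool) -> dict:
    """Return a copy of the parsed workflow config with the base-model path
    injected into the model kwargs (hypothetical sketch; key names assumed)."""
    cfg = copy.deepcopy(workflow_cfg)  # leave the caller's config untouched
    kwargs = cfg["task"]["model"]["kwargs"]
    if no_pretrain:
        kwargs.pop("model_path", None)  # --no-pretrain: fall back to random init
    else:
        kwargs["model_path"] = pkl_path
    return cfg

# Usage on a minimal parsed-YAML stand-in
cfg = {"task": {"model": {"kwargs": {"d_feat": 158}}}}
patched = inject_pretrain_path(
    cfg, "data/pretrained/lstm_Alpha158_latest.pkl", no_pretrain=False
)
print(patched["task"]["model"]["kwargs"]["model_path"])
```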

quantpits/scripts/plot_model_opinions.py

Lines changed: 4 additions & 1 deletion
@@ -67,11 +67,13 @@ def main():
     # X axis: models/combos; Y axis: rank; each line represents one instrument
     plt.figure(figsize=(14, 8))
 
+    plotted = False
     for instrument in rank_df.index:
         y_values = rank_df.loc[instrument]
         if y_values.isna().all():
             continue
         plt.plot(rank_df.columns, y_values, marker='o', alpha=0.7, label=instrument)
+        plotted = True
 
     # Invert the Y axis so that rank 1 appears at the top
     plt.gca().invert_yaxis()
@@ -83,7 +85,8 @@ def main():
     plt.title(f'Model Prediction Rank Comparison - {os.path.basename(csv_file)}')
 
     # Legend outside
-    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0., title="Instrument")
+    if plotted:
+        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0., title="Instrument")
 
     plt.tight_layout()
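The `plotted` guard above avoids calling `plt.legend()` when every row was all-NaN and nothing was drawn (matplotlib warns about legends with no labeled artists). The control flow can be distilled without matplotlib:

```python
import math

def instruments_to_plot(rank_rows: dict) -> list:
    """Mirror the loop above: skip all-NaN rows and return the instruments
    that would actually be plotted (pure-Python distillation of the guard)."""
    plotted = []
    for instrument, values in rank_rows.items():
        if all(math.isnan(v) for v in values):
            continue  # same role as the `continue` in the diff
        plotted.append(instrument)
    return plotted

rows = {"SH600000": [1.0, 2.0], "SH600519": [float("nan")] * 2}
print(instruments_to_plot(rows))  # → ['SH600000']; the legend would be drawn
```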
