A Modular Adapter Routing Framework for Multi-Task Learning
Plug-in Expert 是一个模块化的多任务学习框架,融合了 LoRA、QLoRA、全参数微调等多种参数高效微调技术,并通过灵活的 硬路由(Hard Router) 与 软路由(Soft Router) 机制,实现任务感知的专家适配器组合与动态推理。该框架设计旨在为多任务设置提供高效、可扩展且可解释的适配解决方案。
该项目在 MMLU 多项选择分类任务上取得了超越传统全参微调与普通 LoRA 的表现,且支持后续扩展至多种真实任务。
- 支持三种适配器机制:
LoRA
QLoRA
(低比特权重量化)全参数微调
- 提供三种路由策略:
- 硬路由(Hard Routing):手动指定 / BERT 分类模型预测
- 软路由(Soft Routing):基于置信度和先验融合
- 多数投票(Majority Voting)
- 支持分类任务(如:IFlytek)与生成任务(MMLU)
- 完整支持实验可复现:支持多配置训练、路由策略消融与数据增强(COT 蒸馏)
FlexLoRA/
├── classification/
│ └── scripts/
│ ├── train_lora.py
│ ├── train_sft.py
│ ├── infer_lora.py
│ └── infer_sft.py
├── generation/
│ └── scripts/
│ ├── train_lora_gen_multiclasses.py
│ ├── infer_lora_gen_hardrouter.py
│ ├── infer_lora_gen_softrouter.py
│ └── infer_lora_gen_majorvoting.py
├── main.py # 统一入口:支持train/infer +
└── ...
本项目使用的 MMLU (Massive Multitask Language Understanding) 数据集包含以下文件类型:
mmlu_5class_balanced_1000_gen.json
: 包含 2000 个泛化数据样本,涵盖 5 个学科类别(STEM、Business、Social Science、Psychology、Biomedical)的平衡分布数据,用于模型泛化能力测试。
mmlu_train_gen_1000_[Domain].json
: 每个学科领域包含 1000 个特化训练样本,其中[Domain]
包括:STEM
: 科学、技术、工程、数学领域Business
: 商业管理领域SocialScience
: 社会科学领域Psychology
: 心理学领域Biomedical
: 生物医学领域
mmlu_train_cot_[Domain].json
: 利用 DeepSeek 模型增强的思维链数据,包含关于答案推理过程的详细说明,帮助模型理解解题思路和逻辑推理过程。
mmlu_test_gen.json
: 用于模型性能评估的测试数据集,包含 2500 个测试样本。
router_train.json
: 用于训练 BERT 路由器的训练数据router_test.json
: 用于测试 BERT 路由器的测试数据
训练完成后,各种模型会自动保存到以下位置:
- LoRA 模型:
classification/model_path/
- 保存格式:
epoch_{epoch}.pt
(每个epoch的检查点) - 包含:LoRA权重、分类层参数和训练配置
- 保存格式:
- 全参数微调模型:
model_path/iflytek_model_sft/final_model/
- 保存完整的模型权重和tokenizer
- 多类别 LoRA 模型:
../model_path/mmlu_lora/
- 每个学科领域单独保存:
model_mmlu_1000_{Domain}/
- 支持的领域:
STEM
,Business
,SocialScience
,Psychology
,Biomedical
- 每个学科领域单独保存:
- COT 增强模型:
../model_path/mmlu_lora/
- 保存格式:
model_mmlu_cot_{Domain}/
- 保存格式:
- BERT 路由器:
./bert_subject_router/
- 用于硬路由策略的学科分类器
model_path/
├── mmlu_lora/
│ ├── model_mmlu_1000_STEM/
│ ├── model_mmlu_1000_Business/
│ ├── model_mmlu_1000_SocialScience/
│ ├── model_mmlu_1000_Psychology/
│ ├── model_mmlu_1000_Biomedical/
│ └── model_mmlu_cot_*/ # COT增强版本
├── iflytek_model_sft/
│ └── final_model/
└── bert_subject_router/
推荐使用 Python 3.10+,建议创建虚拟环境:
conda create -n flexlora python=3.10
conda activate flexlora
pip install -r requirements.txt
依赖框架:
transformers
,datasets
,scikit-learn
,numpy
,torch
,accelerate
等
python main.py --task Classification --mode train --training_method lora
# 或使用 QLoRA
python main.py --task Classification --mode train --training_method qlora
python main.py --task Classification --mode infer --training_method lora
python main.py --task Generation --mode train
**使用全参数微调(可选)**
python generation/scripts/train_lora_gen.py
训练 BERT 路由器(可选)
python generation/scripts/classifybert.py
硬路由:
python main.py --task Generation --mode infer --inference_method hard_route
软路由:
python main.py --task Generation --mode infer --inference_method soft_route --soft_route_method prior&confidence_fusion
多数投票:
python main.py --task Generation --mode infer --inference_method soft_route --soft_route_method major_voting
方法 | Overall | STEM | Business | SocialScience | Psychology | Biomedical |
---|---|---|---|---|---|---|
Baseline | 0.572 | 0.420 | 0.600 | 0.590 | 0.680 | 0.570 |
+LoRA | 0.588 | 0.460 | 0.630 | 0.630 | 0.670 | 0.550 |
Hard Router | 0.626 | 0.510 | 0.650 | 0.700 | 0.700 | 0.570 |
Soft Router | 0.620 | 0.520 | 0.640 | 0.670 | 0.700 | 0.570 |
Major Voting | 0.612 | 0.480 | 0.650 | 0.670 | 0.690 | 0.570 |