
Commit b9d98fd

feat(shortgpt): add new script and polish existing code (typos, paths)
1 parent b8f2101 commit b9d98fd

File tree: 9 files changed, +523 -59 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -140,4 +140,4 @@ autogen/
 #fp8
 ops/csrc/fp8/deep_gemm/include/cutlass
 ops/csrc/fp8/deep_gemm/include/cute
-.ccls-cache
+.ccls-cache

slm/pipelines/examples/contrastive_training/README.md

Lines changed: 52 additions & 11 deletions
@@ -1,17 +1,23 @@
 # Vector Retrieval Model Training
 
-We recommend installing the GPU version of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html). Taking the CUDA 12.3 build of Paddle as an example, the installation commands are:
+We recommend installing the GPU version of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html). Taking the CUDA 11.8 build of Paddle as an example, the installation commands are:
 
 ```
-conda install nccl -c conda-forge
-conda install paddlepaddle-gpu==3.0.0rc1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/ -c conda-forge
-```
-Install the other dependencies:
-```
-pip install git+https://github.com/PaddlePaddle/PaddleNLP.git@develop
-pip install -r requirements.txt
+# Create a new environment named paddle_env and activate it
+conda create --name paddle_env python=3.10
+conda activate paddle_env
+
+# Install the develop version of paddlenlp
+pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html
+
+# Install the nightly build of paddlepaddle-gpu
+pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/
+
+# Install the other dependencies:
+pip install -r slm/pipelines/examples/contrastive_training/requirements.txt
 ```
 
+
 Download the DuReader-Retrieval Chinese dataset:
 ```
 cd data
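The install block above pulls a nightly paddlepaddle-gpu wheel. A quick way to confirm the GPU build is actually usable (a common sanity check, not part of this diff) is PaddlePaddle's built-in self-test:

```python
# Optional post-install sanity check (not part of this commit).
import paddle

paddle.utils.run_check()   # runs a small job to verify PaddlePaddle is installed correctly and can reach the GPU
print(paddle.__version__)  # should report the installed nightly/dev build
```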
@@ -45,7 +51,7 @@ python train.py --do_train \
 python -m paddle.distributed.launch --gpus "0,1,2,3" train.py --do_train \
     --model_name_or_path rocketqa-zh-base-query-encoder \
     --output_dir rocketqa-zh-base-query-encoder-duretrieval \
-    --train_data ./data/dual.train.json \
+    --train_data ./data/dureader_dual.train.jsonl \
     --overwrite_output_dir \
     --fine_tune_type sft \
     --sentence_pooling_method cls \
@@ -132,7 +138,7 @@ python evaluation/benchmarks.py --model_type bert \
     --passage_max_length 512 \
 ```
 Configurable parameters include:
-- `model_type`: the model's similarity, e.g. bert or roberta
+- `model_type`: the model type, e.g. bert or roberta
 - `query_model`: path to the query embedding model
 - `passage_model`: path to the passage embedding model
 - `query_max_length`: maximum length of the query
@@ -179,7 +185,7 @@ python -u evaluation/eval_mteb.py \
 - `add_bos_token`: whether to add the beginning-of-sequence token; 0 means do not add, 1 means add
 - `add_eos_token`: whether to add the end-of-sequence token; 0 means do not add, 1 means add
 
-# MTEB Evaluation
+## MTEB Evaluation
 [MTEB](https://github.com/embeddings-benchmark/mteb)
 is a large-scale text-embedding benchmark covering a rich set of retrieval evaluation tasks and datasets.
 This repository mainly targets its English retrieval tasks (Retrieval) and additionally supports evaluation on MSMARCO-Title.
@@ -261,6 +267,39 @@ MTEB-Retrieval datasets, NDCG@10 scores:
 | LLARA-passage | 52.48 | 47.51 | 26.13 | 37.26 | 44.12 | 81.09 | 43.98 | 69.17 | 45.49 | 37.07 | 61.76 | 82.29 | 17.30 | 76.07 | 36.73 | 81.30 |
 
 
+## Compression
+
+### Layer pruning
+The pruning script `shortgpt_prune.py` evaluates and removes the least important layers of a large language model to produce a smaller, more efficient model. It uses the "Block Influence" metric to measure layer importance, and it performs the pruning and saving entirely in memory, which keeps the workflow efficient.
+
+#### Usage
+
+Run the pruning script with the following commands. You can specify the original model, the output path, the number of layers to prune, and the attribute path to the transformer layers inside the model.
+
+Taking repllama-v1-7b-lora-passage as an example:
+```bash
+python shortgpt_prune.py \
+    --model_name_or_path castorini/repllama-v1-7b-lora-passage \
+    --output_model_path ./pruned-repllama-v1-7b-lora-passage \
+    --n_prune_layers 6 \
+    --layers_path "llama.layers"
+```
+
+Taking NV-Embed-v1 as an example:
+```bash
+python shortgpt_prune.py \
+    --model_name_or_path nvidia/NV-Embed-v1 \
+    --output_model_path /pruned-NV-Embed-v1_pruned_26 \
+    --n_prune_layers 6 \
+    --layers_path "layers"
+```
+Configurable parameters include:
+- `--model_name_or_path`: name or local path of the original model.
+- `--output_model_path`: path where the pruned model is saved.
+- `--n_prune_layers`: number of layers to remove. The script automatically selects the N least important layers.
+- `--layers_path`: dot-separated path to the list of transformer layers inside the model object (e.g. `"llama.layers"` for repllama, `"model.layers"` for llama).
+
+The pruned model saved at `output_model_path` can then be evaluated with the [evaluation code](#评估).
 
 ## Reference
 
@@ -281,3 +320,5 @@ MTEB-Retrieval datasets, NDCG@10 scores:
 [8] Yingqi Qu, Yuchen Ding, Jing Liu, Kai Liu, Ruiyang Ren, Wayne Xin Zhao, Daxiang Dong, Hua Wu, Haifeng Wang: RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering. NAACL 2021
 
 [9] Ruiyang Ren, Yingqi Qu, Jing Liu, Wayne Xin Zhao, Qiaoqiao She, Hua Wu, Haifeng Wang, Ji-Rong Wen: RocketQAv2: A Joint Training Method for Dense Passage Retrieval and Passage Re-ranking. EMNLP 2021
+
+[10] Xin Men, Mingyu Xu, Qingyu Zhang, Bingning Wang, Hongyu Lin, Yaojie Lu, Xianpei Han, Weipeng Chen: ShortGPT: Layers in Large Language Models Are More Redundant Than You Expect. ACL Findings 2025
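For context on the "Block Influence" metric that the new compression section (and `shortgpt_prune.py`) relies on: ShortGPT [10] scores each transformer layer by how little it changes its input hidden state, so a layer whose output stays close in cosine similarity to its input contributes little and is a pruning candidate. The snippet below is a minimal illustrative sketch of that scoring and of resolving a dotted `--layers_path`; the function names are hypothetical and it is not the repository's `shortgpt_prune.py` implementation.

```python
# Minimal sketch of ShortGPT-style Block Influence scoring (hypothetical, NumPy only).
import numpy as np

def block_influence(h_in: np.ndarray, h_out: np.ndarray, eps: float = 1e-8) -> float:
    """BI of one layer = 1 - mean cosine similarity between its input and output
    hidden states over a calibration batch; a lower score means a less important layer."""
    cos = (h_in * h_out).sum(-1) / (
        np.linalg.norm(h_in, axis=-1) * np.linalg.norm(h_out, axis=-1) + eps
    )
    return float(1.0 - cos.mean())

def least_important_layers(hidden_states: list, n_prune_layers: int) -> list:
    """hidden_states = [h_0, h_1, ..., h_L] collected on calibration data,
    where layer i maps h_i -> h_{i+1}. Returns indices of the n layers to drop."""
    scores = [block_influence(hidden_states[i], hidden_states[i + 1])
              for i in range(len(hidden_states) - 1)]
    return sorted(range(len(scores)), key=scores.__getitem__)[:n_prune_layers]

def resolve_layers(model, layers_path: str):
    """Resolve a dotted path such as "llama.layers" or "model.layers" to the
    layer list, mirroring what a --layers_path style argument would do."""
    obj = model
    for attr in layers_path.split("."):
        obj = getattr(obj, attr)
    return obj
```

In practice the hidden states would be collected from a forward pass over a small calibration set before ranking the layers and deleting the lowest-scoring ones.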

slm/pipelines/examples/contrastive_training/evaluation/benchmarks.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 
 from datasets import load_dataset
 from mteb.abstasks import AbsTaskRetrieval
-from prediction import Eval_modle
+from prediction import Eval_model
 
 csv.field_size_limit(500 * 1024 * 1024)
 
@@ -51,7 +51,7 @@ def __init__(
         pooling_mode="mean_tokens",
         **kwargs,
     ):
-        self.query_model = Eval_modle(
+        self.query_model = Eval_model(
            model=query_model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,

slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py

Lines changed: 5 additions & 2 deletions
@@ -29,9 +29,9 @@
 class MSMARCOTITLE(AbsTaskRetrieval):
     metadata = TaskMetadata(
         dataset={
-            "corpus_path": "Tevatron/msmarco-passage-corpus",
+            "corpus_path": "Tevatron/msmarco-passage-corpus-new",
             "path": "mteb/msmarco",
-            "revision": "c5a29a104738b98a9e76336939199e264163d4a0",
+            "revision": "c5a29a104738b98a9e76336939199e264163d4a0",
         },
         name="MSMARCOTITLE",
         description="MS MARCO is a collection of datasets focused on deep learning in search",
@@ -53,6 +53,9 @@ class MSMARCOTITLE(AbsTaskRetrieval):
         bibtex_citation=None,
         n_samples=None,
         avg_character_length=None,
+        modalities=["text"],
+        sample_creation="created",
+        descriptive_stats={},
     )
 
     def load_data(self, **kwargs):
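The three added `TaskMetadata` fields (`modalities`, `sample_creation`, `descriptive_stats`) align the custom task with what newer mteb releases expect. As a hedged sketch, not code taken from this repo's `eval_mteb.py`, and assuming an embedding `model` object exposing the `encode` interface mteb requires, such a custom retrieval task is typically run like this:

```python
# Hypothetical usage sketch: running the custom MSMARCOTITLE task through mteb.
from mteb import MTEB

evaluation = MTEB(tasks=[MSMARCOTITLE()])  # the AbsTaskRetrieval subclass defined above
results = evaluation.run(model, output_folder="en_results/my-model")  # `model` must implement encode()
```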

slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.sh

Lines changed: 113 additions & 30 deletions
@@ -12,62 +12,114 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+#!/bin/bash
 
-for task in "ArguAna" "ClimateFEVER" "DBPedia" "FEVER" "FiQA2018" "HotpotQA" "MSMARCO" "NFCorpus" "NQ" "QuoraRetrieval" "SCIDOCS" "SciFact" "Touche2020" "TRECCOVID" "CQADupstackAndroidRetrieval" "CQADupstackEnglishRetrieval" "CQADupstackGamingRetrieval" "CQADupstackGisRetrieval" "CQADupstackMathematicaRetrieval" "CQADupstackPhysicsRetrieval" "CQADupstackProgrammersRetrieval" "CQADupstackStatsRetrieval" "CQADupstackTexRetrieval" "CQADupstackUnixRetrieval" "CQADupstackWebmastersRetrieval" "CQADupstackWordpressRetrieval" "MSMARCOTITLE"
-do
+# --- Script Configuration ---
+# Exit immediately if a command exits with a non-zero status.
+set -e
 
-# 1. RocketQA V1
-python3.10 -u eval_mteb.py \
-    --corpus_model_name_or_path rocketqa-en-base-v1/passage_model \
-    --query_model_name_or_path rocketqa-en-base-v1/query_model \
+# Define the list of all tasks (datasets) to be evaluated.
+# TASKS=(
+#     "ArguAna" "ClimateFEVER" "DBPedia" "FEVER" "FiQA2018" "HotpotQA" "MSMARCO" "NFCorpus" "NQ" "QuoraRetrieval"
+#     "SCIDOCS" "SciFact" "Touche2020" "TRECCOVID" "CQADupstackAndroidRetrieval" "CQADupstackEnglishRetrieval"
+#     "CQADupstackGamingRetrieval" "CQADupstackGisRetrieval" "CQADupstackMathematicaRetrieval" "CQADupstackPhysicsRetrieval"
+#     "CQADupstackProgrammersRetrieval" "CQADupstackStatsRetrieval" "CQADupstackTexRetrieval" "CQADupstackUnixRetrieval"
+#     "CQADupstackWebmastersRetrieval" "CQADupstackWordpressRetrieval" "MSMARCOTITLE"
+# )
+
+TASKS=("ArguAna" "SCIDOCS" "FEVER")
+
+
+# You can uncomment the models you wish to evaluate.
+# MODELS_TO_RUN=("RocketQA-V1" "RocketQA-V2" "BGE" "RepLLaMA" "NV-Embed-v1" "BGE-EN-ICL" "LLARA-passage")
+MODELS_TO_RUN=("BGE")
+
+
+# ===================================================================================
+# 🚀 1. RocketQA V1
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " RocketQA-V1 " ]]; then
+  echo "===== Running Evaluation for Model: RocketQA V1 ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 -u evaluation/eval_mteb.py \
+        --corpus_model_name_or_path rocketqa-v1-marco-para-encoder \
+        --query_model_name_or_path rocketqa-v1-marco-query-encoder \
         --model_flag RocketQA-V1 \
         --output_folder en_results/rocketqa-en-base-v1 \
         --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
         --query_instruction "" \
         --document_instruction "" \
         --max_seq_length 512 \
         --eval_batch_size 32 \
         --dtype "float32" \
         --padding_side right \
         --pooling_method "cls"
+  done
+fi
+
 
-# 2. RocketQA V2
-python3.10 -u eval_mteb.py \
-    --corpus_model_name_or_path rocketqa-en-base-v2/passage_model \
-    --query_model_name_or_path rocketqa-en-base-v2/query_model \
+# ===================================================================================
+# 🚀 2. RocketQA V2
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " RocketQA-V2 " ]]; then
+  echo "===== Running Evaluation for Model: RocketQA V2 ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 -u evaluation/eval_mteb.py \
+        --corpus_model_name_or_path rocketqav2-en-marco-para-encoder \
+        --query_model_name_or_path rocketqav2-en-marco-query-encoder \
        --model_flag RocketQA-V2 \
        --output_folder en_results/rocketqa-en-base-v2 \
        --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
        --query_instruction "" \
        --document_instruction "" \
        --max_seq_length 512 \
        --eval_batch_size 128 \
        --dtype "float32" \
        --padding_side right \
        --pooling_method "cls"
+  done
+fi
+
 
-# 3. BGE
-python3.10 eval_mteb.py \
+# ===================================================================================
+# 🎯 3. BGE (BAAI/bge-large-en-v1.5)
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " BGE " ]]; then
+  echo "===== Running Evaluation for Model: BGE (bge-large-en-v1.5) ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 evaluation/eval_mteb.py \
        --base_model_name_or_path BAAI/bge-large-en-v1.5 \
-        --output_folder en_results/bge-large-en-v1.5 \
+        --output_folder en_results/bge-large-en-v1.5_2 \
        --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
        --document_instruction 'Represent this sentence for searching relevant passages: ' \
        --pooling_method mean \
        --max_seq_length 512 \
        --eval_batch_size 32 \
        --padding_side right \
        --add_bos_token 0 \
-        --add_eos_token 0
+        --add_eos_token 0
+  done
+fi
 
-# 4. RepLLaMA
-python3.10 eval_mteb.py \
+
+# ===================================================================================
+# 🦙 4. RepLLaMA
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " RepLLaMA " ]]; then
+  echo "===== Running Evaluation for Model: RepLLaMA ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 evaluation/eval_mteb.py \
        --base_model_name_or_path castorini/repllama-v1-7b-lora-passage \
        --output_folder en_results/repllama-v1-7b-lora-passage \
        --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
        --query_instruction 'query: ' \
        --document_instruction 'passage: ' \
        --pooling_method last \
@@ -76,41 +128,72 @@ do
        --padding_side right \
        --add_bos_token 0 \
        --add_eos_token 1
+  done
+fi
+
 
-# 5. NV-Embed-v1
-python3.10 eval_mteb.py \
+# ===================================================================================
+# Nvidia 5. NV-Embed-v1
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " NV-Embed-v1 " ]]; then
+  echo "===== Running Evaluation for Model: NV-Embed-v1 ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 evaluation/eval_mteb.py \
        --base_model_name_or_path nvidia/NV-Embed-v1 \
        --output_folder en_results/nv-embed-v1 \
        --query_instruction "Given a claim, find documents that refute the claim" \
        --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
        --eval_batch_size 8
+  done
+fi
+
 
-# 6. BGE-EN-ICL
-python3.10 eval_mteb.py \
+# ===================================================================================
+# 🎯 6. BGE-EN-ICL
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " BGE-EN-ICL " ]]; then
+  echo "===== Running Evaluation for Model: BGE-EN-ICL ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 evaluation/eval_mteb.py \
        --base_model_name_or_path BAAI/bge-en-icl \
        --output_folder en_results/bge-en-icl \
        --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
        --query_instruction $'<instruct> Given a scientific claim, retrieve documents that support or refute the claim.\n<query>' \
        --max_seq_length 512 \
        --eval_batch_size 32 \
        --dtype "float32" \
        --padding_side left \
        --add_bos_token 1 \
        --add_eos_token 1
+  done
+fi
 
-# 7. LLARA-passage
-python3.10 eval_mteb.py \
+
+# ===================================================================================
+# 🦙 7. LLARA-passage
+# ===================================================================================
+if [[ " ${MODELS_TO_RUN[*]} " =~ " LLARA-passage " ]]; then
+  echo "===== Running Evaluation for Model: LLARA-passage ====="
+  for task in "${TASKS[@]}"; do
+    echo "--- Task: $task ---"
+    python3.10 evaluation/eval_mteb.py \
        --base_model_name_or_path BAAI/LLARA-passage \
        --output_folder en_results/llara-passage \
        --task_name "$task" \
-        --task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
+        --task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
        --eval_batch_size 8 \
        --pooling_method last_8 \
        --model_flag llara \
        --add_bos_token 1 \
        --add_eos_token 0 \
        --max_seq_length 532
+  done
+fi
+
+
 
-done
+echo "All specified evaluations are complete."
