feat(shortgpt): add new script and polish existing code (typos, paths) #10882

Open · wants to merge 2 commits into develop
2 changes: 1 addition & 1 deletion .gitignore
@@ -140,4 +140,4 @@ autogen/
#fp8
ops/csrc/fp8/deep_gemm/include/cutlass
ops/csrc/fp8/deep_gemm/include/cute
.ccls-cache
.ccls-cache
63 changes: 52 additions & 11 deletions slm/pipelines/examples/contrastive_training/README.md
@@ -1,17 +1,23 @@
# Vector Retrieval Model Training

We recommend installing the GPU build of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html); taking the cuda12.3 build of paddle as an example, the install commands are:
We recommend installing the GPU build of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html); taking the cuda11.8 build of paddle as an example, the install commands are:
Collaborator: We recommend targeting a newer version of Paddle; older versions are not kept available for long.

Author (@tianyumyum, Jul 23, 2025): For adapting to a higher CUDA version, which one specifically, nightly or stable? The CUDA version on my server is currently 12.1.

Collaborator: You can consider the current CUDA 12.6 and 12.9.


```
conda install nccl -c conda-forge
conda install paddlepaddle-gpu==3.0.0rc1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/ -c conda-forge
```
Install other dependencies:
```
pip install git+https://github.com/PaddlePaddle/PaddleNLP.git@develop
pip install -r requirements.txt
# Create a new environment named paddle_env and activate it
conda create --name paddle_env python=3.10
conda activate paddle_env

# Install the develop build of paddlenlp
pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html

# Install the nightly build of paddlepaddle-gpu
pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/

# Install other dependencies:
pip install -r slm/pipelines/examples/contrastive_training/requirements.txt
```


Download the DuReader-Retrieval Chinese dataset:
```
cd data
@@ -45,7 +51,7 @@ python train.py --do_train \
python -m paddle.distributed.launch --gpus "0,1,2,3" train.py --do_train \
--model_name_or_path rocketqa-zh-base-query-encoder \
--output_dir rocketqa-zh-base-query-encoder-duretrieval \
--train_data ./data/dual.train.json \
--train_data ./data/dureader_dual.train.jsonl \
--overwrite_output_dir \
--fine_tune_type sft \
--sentence_pooling_method cls \
@@ -132,7 +138,7 @@ python evaluation/benchmarks.py --model_type bert \
--passage_max_length 512 \
```
Configurable parameters include:
- `model_type`: the similar of the model; options include bert, roberta, etc.
- `model_type`: the type of the model; options include bert, roberta, etc.
- `query_model`: path of the query embedding model
- `passage_model`: path of the passage embedding model
- `query_max_length`: maximum length of the query
@@ -179,7 +185,7 @@ python -u evaluation/eval_mteb.py \
- `add_bos_token`: whether to add the beginning-of-sequence token; 0 means do not add, 1 means add
- `add_eos_token`: whether to add the end-of-sequence token; 0 means do not add, 1 means add

# MTEB Evaluation
## MTEB Evaluation
[MTEB](https://github.com/embeddings-benchmark/mteb)
is a large-scale text-embedding benchmark that covers a rich set of vector retrieval evaluation tasks and datasets.
This repository focuses on its English retrieval (Retrieval) tasks and additionally supports evaluation on MSMARCO-Title.
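The repository's `eval_mteb.py` builds on the `mteb` package (its task classes subclass `AbsTaskRetrieval`). For orientation, a minimal retrieval evaluation with that package's standard interface typically looks like the sketch below; the encoder class and output folder are illustrative placeholders, and the exact API may differ across `mteb` versions.

```python
# Minimal, illustrative sketch of running one MTEB retrieval task.
# `MyEncoder` is a placeholder: MTEB only requires an `encode` method that
# maps a list of strings to a 2-D array of embeddings.
import numpy as np
from mteb import MTEB

class MyEncoder:
    def encode(self, sentences, batch_size=32, **kwargs):
        # Replace with a real embedding model; random vectors are for illustration only.
        return np.random.rand(len(sentences), 768)

evaluation = MTEB(tasks=["SciFact"])
results = evaluation.run(MyEncoder(), output_folder="en_results/my-model")
```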
@@ -261,6 +267,39 @@ MTEB-Retrieval datasets, NDCG@10 scores:
| LLARA-passage | 52.48 | 47.51 | 26.13 | 37.26 | 44.12 | 81.09 | 43.98 | 69.17 | 45.49 | 37.07 | 61.76 | 82.29 | 17.30 | 76.07 | 36.73 | 81.30 |


## Compression

### Layer Pruning
The pruning script `shortgpt_prune.py` evaluates and removes the less important layers of a large language model to produce a smaller, more efficient model. It scores layer importance with the "Block Influence" metric and performs the pruning and saving directly in memory, keeping the workflow efficient; a minimal sketch of the metric follows.
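The script itself is not part of this diff, so the following is only a sketch of the Block Influence idea from ShortGPT (reference [10]), assuming that a layer's importance is one minus the average cosine similarity between the hidden states entering and leaving it, and that the `--n_prune_layers` lowest-scoring layers are removed; the function names are illustrative, not the script's actual API.

```python
# Illustrative sketch of the Block Influence (BI) metric, not the actual
# shortgpt_prune.py implementation.
import paddle
import paddle.nn.functional as F

def block_influence(hidden_in, hidden_out):
    """BI for one layer: 1 - mean cosine similarity between the hidden states
    entering (hidden_in) and leaving (hidden_out) the layer.
    Both tensors have shape [batch, seq_len, hidden_size]."""
    cos = F.cosine_similarity(hidden_in, hidden_out, axis=-1)  # [batch, seq_len]
    return (1.0 - cos.mean()).item()

def least_important_layers(bi_scores, n_prune_layers):
    """Indices of the n layers with the lowest BI (the candidates for removal)."""
    order = sorted(range(len(bi_scores)), key=lambda i: bi_scores[i])
    return sorted(order[:n_prune_layers])
```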

#### Usage

Run the pruning script with the commands below. You can specify the original model, the output path, the number of layers to prune, and the attribute path of the transformer layers inside the model.

Taking repllama-v1-7b-lora-passage as an example:
```bash
python shortgpt_prune.py \
--model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_model_path ./pruned-repllama-v1-7b-lora-passage \
--n_prune_layers 6 \
--layers_path "llama.layers"
```

Taking NV-Embed-v1 as an example:
```bash
python shortgpt_prune.py \
--model_name_or_path nvidia/NV-Embed-v1 \
--output_model_path /pruned-NV-Embed-v1_pruned_26 \
--n_prune_layers 6 \
--layers_path "layers"
```
Configurable parameters include:
- `--model_name_or_path`: name or local path of the original model.
- `--output_model_path`: path where the pruned model is saved.
- `--n_prune_layers`: number of layers to remove; the script automatically identifies the N least important layers.
- `--layers_path`: dot-separated attribute path in the model object that points to the list of transformer layers (e.g. `"llama.layers"` for repllama, `"model.layers"` for llama); see the sketch after this list.
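How the dot-separated `--layers_path` could be resolved against the loaded model is sketched below; this is an illustrative assumption, not necessarily the script's exact code.

```python
# Illustrative: resolve a dot-separated attribute path such as "llama.layers"
# or "model.layers" to the container holding the transformer blocks.
from functools import reduce

def get_layers(model, layers_path):
    return reduce(getattr, layers_path.split("."), model)

# e.g. layers = get_layers(model, "llama.layers")
# kept = [layer for i, layer in enumerate(layers) if i not in layers_to_remove]
```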

The model saved at `output_model_path` can then be evaluated with the [code in the evaluation section](#评估).

## Reference

@@ -281,3 +320,5 @@ MTEB-Retrieval datasets, NDCG@10 scores:
[8] Yingqi Qu, Yuchen Ding, Jing Liu, Kai Liu, Ruiyang Ren, Wayne Xin Zhao, Daxiang Dong, Hua Wu, Haifeng Wang: RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering. NAACL 2021

[9] Ruiyang Ren, Yingqi Qu, Jing Liu, Wayne Xin Zhao, Qiaoqiao She, Hua Wu, Haifeng Wang, Ji-Rong Wen: RocketQAv2: A Joint Training Method for Dense Passage Retrieval and Passage Re-ranking. EMNLP 2021

[10] Xin Men, Mingyu Xu, Qingyu Zhang, Bingning Wang, Hongyu Lin, Yaojie Lu, Xianpei Han, Weipeng Chen: ShortGPT: Layers in Large Language Models Are More Redundant Than You Expect. ACL Findings 2025
@@ -21,7 +21,7 @@

from datasets import load_dataset
from mteb.abstasks import AbsTaskRetrieval
from prediction import Eval_modle
from prediction import Eval_model

csv.field_size_limit(500 * 1024 * 1024)

@@ -51,7 +51,7 @@ def __init__(
pooling_mode="mean_tokens",
**kwargs,
):
self.query_model = Eval_modle(
self.query_model = Eval_model(
model=query_model,
max_seq_len=max_seq_len,
batch_size=batch_size,
@@ -29,7 +29,7 @@
class MSMARCOTITLE(AbsTaskRetrieval):
metadata = TaskMetadata(
dataset={
"corpus_path": "Tevatron/msmarco-passage-corpus",
"corpus_path": "Tevatron/msmarco-passage-corpus-new",
"path": "mteb/msmarco",
"revision": "c5a29a104738b98a9e76336939199e264163d4a0",
},
@@ -53,6 +53,9 @@ class MSMARCOTITLE(AbsTaskRetrieval):
bibtex_citation=None,
n_samples=None,
avg_character_length=None,
modalities=["text"],
sample_creation="created",
descriptive_stats={},
)

def load_data(self, **kwargs):
143 changes: 113 additions & 30 deletions slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.sh
@@ -12,62 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash

for task in "ArguAna" "ClimateFEVER" "DBPedia" "FEVER" "FiQA2018" "HotpotQA" "MSMARCO" "NFCorpus" "NQ" "QuoraRetrieval" "SCIDOCS" "SciFact" "Touche2020" "TRECCOVID" "CQADupstackAndroidRetrieval" "CQADupstackEnglishRetrieval" "CQADupstackGamingRetrieval" "CQADupstackGisRetrieval" "CQADupstackMathematicaRetrieval" "CQADupstackPhysicsRetrieval" "CQADupstackProgrammersRetrieval" "CQADupstackStatsRetrieval" "CQADupstackTexRetrieval" "CQADupstackUnixRetrieval" "CQADupstackWebmastersRetrieval" "CQADupstackWordpressRetrieval" "MSMARCOTITLE"
do
# --- Script Configuration ---
# Exit immediately if a command exits with a non-zero status.
set -e

# 1. RocketQA V1
python3.10 -u eval_mteb.py \
--corpus_model_name_or_path rocketqa-en-base-v1/passage_model \
--query_model_name_or_path rocketqa-en-base-v1/query_model \
# Define the list of all tasks (datasets) to be evaluated.
# TASKS=(
# "ArguAna" "ClimateFEVER" "DBPedia" "FEVER" "FiQA2018" "HotpotQA" "MSMARCO" "NFCorpus" "NQ" "QuoraRetrieval"
# "SCIDOCS" "SciFact" "Touche2020" "TRECCOVID" "CQADupstackAndroidRetrieval" "CQADupstackEnglishRetrieval"
# "CQADupstackGamingRetrieval" "CQADupstackGisRetrieval" "CQADupstackMathematicaRetrieval" "CQADupstackPhysicsRetrieval"
# "CQADupstackProgrammersRetrieval" "CQADupstackStatsRetrieval" "CQADupstackTexRetrieval" "CQADupstackUnixRetrieval"
# "CQADupstackWebmastersRetrieval" "CQADupstackWordpressRetrieval" "MSMARCOTITLE"
# )

TASKS=("ArguAna" "SCIDOCS" "FEVER")


# You can uncomment the models you wish to evaluate.
# MODELS_TO_RUN=("RocketQA-V1" "RocketQA-V2" "BGE" "RepLLaMA" "NV-Embed-v1" "BGE-EN-ICL" "LLARA-passage")
MODELS_TO_RUN=("BGE")


# ===================================================================================
# 🚀 1. RocketQA V1
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " RocketQA-V1 " ]]; then
echo "===== Running Evaluation for Model: RocketQA V1 ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 -u evaluation/eval_mteb.py \
--corpus_model_name_or_path rocketqa-v1-marco-para-encoder \
--query_model_name_or_path rocketqa-v1-marco-query-encoder \
--model_flag RocketQA-V1 \
--output_folder en_results/rocketqa-en-base-v1 \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--query_instruction "" \
--document_instruction "" \
--max_seq_length 512 \
--eval_batch_size 32 \
--dtype "float32" \
--padding_side right \
--pooling_method "cls"
done
fi


# 2. RocketQA V2
python3.10 -u eval_mteb.py \
--corpus_model_name_or_path rocketqa-en-base-v2/passage_model \
--query_model_name_or_path rocketqa-en-base-v2/query_model \
# ===================================================================================
# 🚀 2. RocketQA V2
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " RocketQA-V2 " ]]; then
echo "===== Running Evaluation for Model: RocketQA V2 ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 -u evaluation/eval_mteb.py \
--corpus_model_name_or_path rocketqav2-en-marco-para-encoder \
--query_model_name_or_path rocketqav2-en-marco-query-encoder \
--model_flag RocketQA-V2 \
--output_folder en_results/rocketqa-en-base-v2 \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--query_instruction "" \
--document_instruction "" \
--max_seq_length 512 \
--eval_batch_size 128 \
--dtype "float32" \
--padding_side right \
--pooling_method "cls"
done
fi


# 3. BGE
python3.10 eval_mteb.py \
# ===================================================================================
# 🎯 3. BGE (BAAI/bge-large-en-v1.5)
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " BGE " ]]; then
echo "===== Running Evaluation for Model: BGE (bge-large-en-v1.5) ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 evaluation/eval_mteb.py \
--base_model_name_or_path BAAI/bge-large-en-v1.5 \
--output_folder en_results/bge-large-en-v1.5 \
--output_folder en_results/bge-large-en-v1.5_2 \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--document_instruction 'Represent this sentence for searching relevant passages: ' \
--pooling_method mean \
--max_seq_length 512 \
--eval_batch_size 32 \
--padding_side right \
--add_bos_token 0 \
--add_eos_token 0
--add_eos_token 0
done
fi

# 4. RepLLaMA
python3.10 eval_mteb.py \

# ===================================================================================
# 🦙 4. RepLLaMA
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " RepLLaMA " ]]; then
echo "===== Running Evaluation for Model: RepLLaMA ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 evaluation/eval_mteb.py \
--base_model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_folder en_results/repllama-v1-7b-lora-passage \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--query_instruction 'query: ' \
--document_instruction 'passage: ' \
--pooling_method last \
@@ -76,41 +128,72 @@ do
--padding_side right \
--add_bos_token 0 \
--add_eos_token 1
done
fi


# 5. NV-Embed-v1
python3.10 eval_mteb.py \
# ===================================================================================
# Nvidia 5. NV-Embed-v1
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " NV-Embed-v1 " ]]; then
echo "===== Running Evaluation for Model: NV-Embed-v1 ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 evaluation/eval_mteb.py \
--base_model_name_or_path nvidia/NV-Embed-v1 \
--output_folder en_results/nv-embed-v1 \
--query_instruction "Given a claim, find documents that refute the claim" \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--eval_batch_size 8
done
fi


# 6. BGE-EN-ICL
python3.10 eval_mteb.py \
# ===================================================================================
# 🎯 6. BGE-EN-ICL
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " BGE-EN-ICL " ]]; then
echo "===== Running Evaluation for Model: BGE-EN-ICL ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 evaluation/eval_mteb.py \
--base_model_name_or_path BAAI/bge-en-icl \
--output_folder en_results/bge-en-icl \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--query_instruction $'<instruct> Given a scientific claim, retrieve documents that support or refute the claim.\n<query>' \
--max_seq_length 512 \
--eval_batch_size 32 \
--dtype "float32" \
--padding_side left \
--add_bos_token 1 \
--add_eos_token 1
done
fi

# 7. LLARA-passage
python3.10 eval_mteb.py \

# ===================================================================================
# 🦙 7. LLARA-passage
# ===================================================================================
if [[ " ${MODELS_TO_RUN[*]} " =~ " LLARA-passage " ]]; then
echo "===== Running Evaluation for Model: LLARA-passage ====="
for task in "${TASKS[@]}"; do
echo "--- Task: $task ---"
python3.10 evaluation/eval_mteb.py \
--base_model_name_or_path BAAI/LLARA-passage \
--output_folder en_results/llara-passage \
--task_name "$task" \
--task_split $(if [[ "$task" == *"MSMARCO"* ]]; then echo "dev"; else echo "test"; fi) \
--task_split $([[ "$task" == *"MSMARCO"* ]] && echo "dev" || echo "test") \
--eval_batch_size 8 \
--pooling_method last_8 \
--model_flag llara \
--add_bos_token 1 \
--add_eos_token 0 \
--max_seq_length 532
done
fi



done
echo "All specified evaluations are complete."