Skip to content

Commit 319c56e

Browse files
authored
Add CrossEncoder and Integrate it into pipelines (#3196)
* Add CrossEncoder and Integrate it into pipelines * Add CrossEncoder README.md
1 parent ea28142 commit 319c56e

33 files changed

+1706
-42
lines changed
Lines changed: 399 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,399 @@
1+
2+
**目录**
3+
4+
* [背景介绍](#背景介绍)
5+
* [CrossEncoder](#CrossEncoder)
6+
* [1. 技术方案和评估指标](#技术方案)
7+
* [2. 环境依赖](#环境依赖)
8+
* [3. 代码结构](#代码结构)
9+
* [4. 数据准备](#数据准备)
10+
* [5. 模型训练](#模型训练)
11+
* [6. 评估](#开始评估)
12+
* [7. 预测](#预测)
13+
* [8. 部署](#部署)
14+
15+
<a name="背景介绍"></a>
16+
17+
# 背景介绍
18+
19+
基于RocketQA的CrossEncoder训练的单塔模型,该模型用于搜索的排序阶段,对召回的结果进行重新排序的作用。
20+
21+
22+
<a name="CrossEncoder"></a>
23+
24+
# CrossEncoder
25+
26+
<a name="技术方案"></a>
27+
28+
## 1. 技术方案和评估指标
29+
30+
### 技术方案
31+
32+
加载基于ERNIE 3.0训练过的RocketQA单塔CrossEncoder模型。
33+
34+
35+
### 评估指标
36+
37+
(1)采用 AUC 指标来评估排序模型的排序效果。
38+
39+
**效果评估**
40+
41+
| 训练方式 | 模型 | AUC |
42+
| ------------ | ------------ |------------ |
43+
| pairwise| ERNIE-Gram |0.801 |
44+
| CrossEncoder | rocketqa-base-cross-encoder |**0.835** |
45+
46+
<a name="环境依赖"></a>
47+
48+
## 2. 环境依赖和安装说明
49+
50+
**环境依赖**
51+
52+
* python >= 3.7
53+
* paddlepaddle >= 2.3.7
54+
* paddlenlp >= 2.3
55+
* pandas >= 0.25.1
56+
* scipy >= 1.3.1
57+
58+
<a name="代码结构"></a>
59+
60+
## 3. 代码结构
61+
62+
以下是本项目主要代码结构及说明:
63+
64+
```
65+
ernie_matching/
66+
├── deply # 部署
67+
├── cpp
68+
├── rpc_client.py # RPC 客户端的bash脚本
69+
├── http_client.py # http 客户端的bash文件
70+
└── start_server.sh # 启动C++服务的脚本
71+
└── python
72+
├── deploy.sh # 预测部署bash脚本
73+
├── config_nlp.yml # Pipeline 的配置文件
74+
├── web_service.py # Pipeline 服务端的脚本
75+
├── rpc_client.py # Pipeline RPC客户端的脚本
76+
└── predict.py # python 预测部署示例
77+
|—— scripts
78+
├── export_model.sh # 动态图参数导出静态图参数的bash文件
79+
├── export_to_serving.sh # 导出 Paddle Serving 模型格式的bash文件
80+
├── train_ce.sh # 匹配模型训练的bash文件
81+
├── evaluate_ce.sh # 评估验证文件bash脚本
82+
├── predict_ce.sh # 匹配模型预测脚本的bash文件
83+
├── export_model.py # 动态图参数导出静态图参数脚本
84+
├── export_to_serving.py # 导出 Paddle Serving 模型格式的脚本
85+
├── data.py # 训练样本的转换逻辑
86+
├── train_ce.py # 模型训练脚本
87+
├── evaluate.py # 评估验证文件
88+
├── predict.py # Pair-wise 模型预测脚本,输出文本对是相似度
89+
90+
```
91+
92+
<a name="数据准备"></a>
93+
94+
## 4. 数据准备
95+
96+
### 数据集说明
97+
98+
样例数据如下:
99+
```
100+
(小学数学教材比较) 关键词:新加坡 新加坡与中国数学教材的特色比较数学教材,教材比较,问题解决 0
101+
徐慧新疆肿瘤医院 头颈部非霍奇金淋巴瘤扩散加权成像ADC值与Ki-67表达相关性分析淋巴瘤,非霍奇金,头颈部肿瘤,磁共振成像 1
102+
抗生素关性腹泻 鼠李糖乳杆菌GG防治消化系统疾病的研究进展鼠李糖乳杆菌,腹泻,功能性胃肠病,肝脏疾病,幽门螺杆菌 0
103+
德州市图书馆 图书馆智慧化建设与融合创新服务研究图书馆;智慧化;阅读服务;融合创新 1
104+
维生素c 综述 维生素C防治2型糖尿病研究进展维生素C;2型糖尿病;氧化应激;自由基;抗氧化剂 0
105+
(白藜芦醇) 关键词:2型糖尿病 2型糖尿病大鼠心肌缺血再灌注损伤转录因子E2相关因子2/血红素氧合酶1信号通路的表达及白藜芦醇的干预研究糖尿病,2型,心肌缺血,再灌注损伤,白藜芦醇 1
106+
融资偏好 创新型企业产业风险、融资偏好与融资选择融资偏好;产业风险;融资选择 1
107+
星载激光雷达 星载激光雷达望远镜主镜超轻量化结构设计超轻量化;拓扑优化;集成优化;RMS;有限元仿真 1
108+
```
109+
110+
111+
### 数据集下载
112+
113+
114+
- [literature_search_rank](https://paddlenlp.bj.bcebos.com/applications/literature_search_rank.zip)
115+
116+
```
117+
├── data # 排序数据集
118+
├── test.csv # 测试集
119+
├── dev_pairwise.csv # 验证集
120+
└── train.csv # 训练集
121+
```
122+
123+
<a name="模型训练"></a>
124+
125+
## 5. 模型训练
126+
127+
**排序模型下载链接:**
128+
129+
130+
|Model|训练参数配置|硬件|MD5|
131+
| ------------ | ------------ | ------------ |-----------|
132+
|[ERNIE-Gram-Sort](https://bj.bcebos.com/v1/paddlenlp/models/ernie_gram_sort.zip)|<div style="width: 150pt">epoch:3 lr:5E-5 bs:64 max_len:64 </div>|<div style="width: 100pt">4卡 v100-16g</div>|d24ece68b7c3626ce6a24baa58dd297d|
133+
134+
135+
### 训练环境说明
136+
137+
138+
- NVIDIA Driver Version: 440.64.00
139+
- Ubuntu 16.04.6 LTS (Docker)
140+
- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
141+
142+
143+
### 单机单卡训练/单机多卡训练
144+
145+
这里采用单机多卡方式进行训练,通过如下命令,指定 GPU 0,1,2,3 卡。如果采用单机单卡训练,只需要把`--gpu`参数设置成单卡的卡号即可
146+
147+
训练的命令如下:
148+
149+
```
150+
unset CUDA_VISIBLE_DEVICES
151+
python -u -m paddle.distributed.launch --gpus "0,1,2,3" --log_dir="logs" train_ce.py \
152+
--device gpu \
153+
--train_set data/train.csv \
154+
--test_file data/dev_pairwise.csv \
155+
--save_dir ./checkpoints \
156+
--model_name_or_path rocketqa-base-cross-encoder \
157+
--batch_size 32 \
158+
--save_steps 10000 \
159+
--max_seq_len 384 \
160+
--learning_rate 1E-5 \
161+
--weight_decay 0.01 \
162+
--warmup_proportion 0.0 \
163+
--logging_steps 10 \
164+
--seed 1 \
165+
--epochs 3 \
166+
--eval_step 1000
167+
```
168+
也可以运行bash脚本:
169+
170+
```
171+
sh scripts/train_ce.sh
172+
```
173+
174+
<a name="评估"></a>
175+
176+
## 6. 评估
177+
178+
179+
```
180+
python evaluate.py --model_name_or_path rocketqa-base-cross-encoder \
181+
--init_from_ckpt checkpoints/model_80000/model_state.pdparams \
182+
--test_file data/dev_pairwise.csv
183+
```
184+
也可以运行bash脚本:
185+
186+
```
187+
sh scripts/evaluate_ce.sh
188+
```
189+
190+
191+
成功运行后会输出下面的指标:
192+
193+
```
194+
eval_dev auc:0.829
195+
```
196+
197+
<a name="预测"></a>
198+
199+
## 7. 预测
200+
201+
### 准备预测数据
202+
203+
待预测数据为 tab 分隔的 tsv 文件,每一行为 1 个文本 Pair,和文本pair的语义索引相似度,(该相似度由召回模型算出,仅供参考),部分示例如下:
204+
205+
```
206+
中西方语言与文化的差异 第二语言习得的一大障碍就是文化差异。 0.5160342454910278
207+
中西方语言与文化的差异 跨文化视角下中国文化对外传播路径琐谈跨文化,中国文化,传播,翻译 0.5145505666732788
208+
中西方语言与文化的差异 从中西方民族文化心理的差异看英汉翻译语言,文化,民族文化心理,思维方式,翻译 0.5141439437866211
209+
中西方语言与文化的差异 中英文化差异对翻译的影响中英文化,差异,翻译的影响 0.5138794183731079
210+
中西方语言与文化的差异 浅谈文化与语言习得文化,语言,文化与语言的关系,文化与语言习得意识,跨文化交际 0.5131710171699524
211+
```
212+
213+
214+
215+
### 开始预测
216+
217+
以上述 demo 数据为例,运行如下命令基于我们开源的rocketqa模型开始计算文本 Pair 的语义相似度:
218+
219+
```shell
220+
unset CUDA_VISIBLE_DEVICES
221+
python predict.py \
222+
--device 'gpu' \
223+
--params_path checkpoints/model_80000/model_state.pdparams \
224+
--model_name_or_path rocketqa-base-cross-encoder \
225+
--test_set data/test.csv \
226+
--topk 10 \
227+
--batch_size 128 \
228+
--max_seq_length 384
229+
```
230+
也可以直接执行下面的命令:
231+
232+
```
233+
sh scripts/predict_ce.sh
234+
```
235+
得到下面的输出,分别是query,title和对应的预测概率:
236+
237+
```
238+
{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '高校\\十四五\\规划中学科建设要处理好五对关系\\十四五\\规划,学科建设,科技创新,人才培养', 'pred_prob': 0.7076062}
239+
{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '校企科研合作项目管理模式创新校企科研合作项目,管理模式,问题,创新', 'pred_prob': 0.64633846}
240+
{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '科研项目管理策略科研项目,项目管理,实施,必要性,策略', 'pred_prob': 0.63166416}
241+
{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '高校科研项目经费管理流程优化研究——以z大学为例高校,科研项目经费\\全流程\\管理,流程优化', 'pred_prob': 0.60351866}
242+
{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '关于推进我院科研发展进程的相关问题研究医院科研,主体,环境,信息化', 'pred_prob': 0.5688347}
243+
{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '医学临床科研选题原则和方法医学临床,科学研究,选题', 'pred_prob': 0.55190295}
244+
```
245+
246+
<a name="部署"></a>
247+
248+
## 8. 部署
249+
250+
### 动转静导出
251+
252+
首先把动态图模型转换为静态图:
253+
254+
```
255+
python export_model.py \
256+
--params_path checkpoints/model_80000/model_state.pdparams \
257+
--model_name_or_path rocketqa-base-cross-encoder \
258+
--output_path=./output
259+
```
260+
也可以运行下面的bash脚本:
261+
262+
```
263+
sh scripts/export_model.sh
264+
```
265+
266+
### Paddle Inference
267+
268+
使用PaddleInference
269+
270+
```
271+
python deploy/python/predict.py --model_dir ./output \
272+
--input_file data/test.csv \
273+
--model_name_or_path rocketqa-base-cross-encoder
274+
```
275+
也可以运行下面的bash脚本:
276+
277+
```
278+
sh deploy/python/deploy.sh
279+
```
280+
得到下面的输出,输出的是样本的query,title以及对应的概率:
281+
282+
```
283+
Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '科研项目管理策略科研项目,项目管理,实施,必要性,策略'} prob: 0.5479063987731934
284+
Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '关于推进我院科研发展进程的相关问题研究医院科研,主体,环境,信息化'} prob: 0.5151925086975098
285+
Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '深圳科技计划对高校科研项目资助现状分析与思考基础研究,高校,科技计划,科技创新'} prob: 0.42983829975128174
286+
Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '普通高校科研管理模式的优化与创新普通高校,科研,科研管理'} prob: 0.465454638004303
287+
```
288+
289+
### Paddle Serving部署
290+
291+
Paddle Serving 的详细文档请参考 [Pipeline_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Python_Pipeline/Pipeline_Design_CN.md)[Serving_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Serving_Design_CN.md),首先把静态图模型转换成Serving的格式:
292+
293+
```
294+
python export_to_serving.py \
295+
--dirname "output" \
296+
--model_filename "inference.pdmodel" \
297+
--params_filename "inference.pdiparams" \
298+
--server_path "serving_server" \
299+
--client_path "serving_client" \
300+
--fetch_alias_names "predict"
301+
302+
```
303+
304+
参数含义说明
305+
* `dirname`: 需要转换的模型文件存储路径,Program 结构文件和参数文件均保存在此目录。
306+
* `model_filename`: 存储需要转换的模型 Inference Program 结构的文件名称。如果设置为 None ,则使用 `__model__` 作为默认的文件名
307+
* `params_filename`: 存储需要转换的模型所有参数的文件名称。当且仅当所有模型参数被保>存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为 None
308+
* `server_path`: 转换后的模型文件和配置文件的存储路径。默认值为 serving_server
309+
* `client_path`: 转换后的客户端配置文件存储路径。默认值为 serving_client
310+
* `fetch_alias_names`: 模型输出的别名设置,比如输入的 input_ids 等,都可以指定成其他名字,默认不指定
311+
* `feed_alias_names`: 模型输入的别名设置,比如输出 pooled_out 等,都可以重新指定成其他模型,默认不指定
312+
313+
也可以运行下面的 bash 脚本:
314+
```
315+
sh scripts/export_to_serving.sh
316+
```
317+
Paddle Serving的部署有两种方式,第一种方式是Pipeline的方式,第二种是C++的方式,下面分别介绍这两种方式的用法:
318+
319+
#### Pipeline方式
320+
321+
修改对应预训练模型的`Tokenizer`
322+
323+
```
324+
self.tokenizer = AutoTokenizer.from_pretrained('rocketqa-base-cross-encoder')
325+
```
326+
327+
启动 Pipeline Server:
328+
329+
```
330+
python web_service.py
331+
```
332+
333+
启动客户端调用 Server。
334+
335+
首先修改rpc_client.py中需要预测的样本:
336+
337+
```
338+
list_data = [{"query":"加强科研项目管理有效促进医学科研工作","title":"科研项目管理策略科研项目,项目管理,实施,必要性,策略"}]`
339+
```
340+
然后运行:
341+
```
342+
python rpc_client.py
343+
```
344+
模型的输出为:
345+
346+
```
347+
PipelineClient::predict pack_data time:1662354188.422532
348+
PipelineClient::predict before time:1662354188.423034
349+
time to cost :0.016808509826660156 seconds
350+
(1,)
351+
[0.5479064]
352+
```
353+
可以看到客户端发送了1条文本,这条文本的相似的概率值。
354+
355+
#### C++的方式
356+
357+
启动C++的Serving:
358+
359+
```
360+
python -m paddle_serving_server.serve --model serving_server --port 8600 --gpu_id 0 --thread 5 --ir_optim True
361+
```
362+
也可以使用脚本:
363+
364+
```
365+
sh deploy/cpp/start_server.sh
366+
```
367+
Client 可以使用 http 或者 rpc 两种方式,rpc 的方式为:
368+
369+
```
370+
python deploy/cpp/rpc_client.py
371+
```
372+
运行的输出为:
373+
374+
```
375+
I0905 05:38:28.876770 28507 general_model.cpp:490] [client]logid=0,client_cost=158.124ms,server_cost=156.385ms.
376+
time to cost :0.15848731994628906 seconds
377+
[0.54790646]
378+
```
379+
可以看到服务端返回了相似度结果
380+
381+
或者使用 http 的客户端访问模式:
382+
383+
```
384+
python deploy/cpp/http_client.py
385+
```
386+
运行的输出为:
387+
```
388+
time to cost :0.13054680824279785 seconds
389+
0.5479064707850817
390+
```
391+
可以看到服务端返回了相似度结果
392+
393+
394+
## Reference
395+
396+
[1] Xiao, Dongling, Yu-Kun Li, Han Zhang, Yu Sun, Hao Tian, Hua Wu, and Haifeng Wang. “ERNIE-Gram: Pre-Training with Explicitly N-Gram Masked Language Modeling for Natural Language Understanding.” ArXiv:2010.12148 [Cs].
397+
398+
[2] Yingqi Qu, Yuchen Ding, Jing Liu, Kai Liu, Ruiyang Ren, Wayne Xin Zhao, Daxiang Dong, Hua Wu, Haifeng Wang:
399+
RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering. NAACL-HLT 2021: 5835-5847

0 commit comments

Comments
 (0)