|
| 1 | + |
| 2 | + **目录** |
| 3 | + |
| 4 | +* [背景介绍](#背景介绍) |
| 5 | +* [CrossEncoder](#CrossEncoder) |
| 6 | + * [1. 技术方案和评估指标](#技术方案) |
| 7 | + * [2. 环境依赖](#环境依赖) |
| 8 | + * [3. 代码结构](#代码结构) |
| 9 | + * [4. 数据准备](#数据准备) |
| 10 | + * [5. 模型训练](#模型训练) |
| 11 | + * [6. 评估](#开始评估) |
| 12 | + * [7. 预测](#预测) |
| 13 | + * [8. 部署](#部署) |
| 14 | + |
| 15 | +<a name="背景介绍"></a> |
| 16 | + |
| 17 | +# 背景介绍 |
| 18 | + |
| 19 | +基于RocketQA的CrossEncoder训练的单塔模型,该模型用于搜索的排序阶段,对召回的结果进行重新排序的作用。 |
| 20 | + |
| 21 | + |
| 22 | +<a name="CrossEncoder"></a> |
| 23 | + |
| 24 | +# CrossEncoder |
| 25 | + |
| 26 | +<a name="技术方案"></a> |
| 27 | + |
| 28 | +## 1. 技术方案和评估指标 |
| 29 | + |
| 30 | +### 技术方案 |
| 31 | + |
| 32 | +加载基于ERNIE 3.0训练过的RocketQA单塔CrossEncoder模型。 |
| 33 | + |
| 34 | + |
| 35 | +### 评估指标 |
| 36 | + |
| 37 | +(1)采用 AUC 指标来评估排序模型的排序效果。 |
| 38 | + |
| 39 | +**效果评估** |
| 40 | + |
| 41 | +| 训练方式 | 模型 | AUC | |
| 42 | +| ------------ | ------------ |------------ | |
| 43 | +| pairwise| ERNIE-Gram |0.801 | |
| 44 | +| CrossEncoder | rocketqa-base-cross-encoder |**0.835** | |
| 45 | + |
| 46 | +<a name="环境依赖"></a> |
| 47 | + |
| 48 | +## 2. 环境依赖和安装说明 |
| 49 | + |
| 50 | +**环境依赖** |
| 51 | + |
| 52 | +* python >= 3.7 |
| 53 | +* paddlepaddle >= 2.3.7 |
| 54 | +* paddlenlp >= 2.3 |
| 55 | +* pandas >= 0.25.1 |
| 56 | +* scipy >= 1.3.1 |
| 57 | + |
| 58 | +<a name="代码结构"></a> |
| 59 | + |
| 60 | +## 3. 代码结构 |
| 61 | + |
| 62 | +以下是本项目主要代码结构及说明: |
| 63 | + |
| 64 | +``` |
| 65 | +ernie_matching/ |
| 66 | +├── deply # 部署 |
| 67 | + ├── cpp |
| 68 | + ├── rpc_client.py # RPC 客户端的bash脚本 |
| 69 | + ├── http_client.py # http 客户端的bash文件 |
| 70 | + └── start_server.sh # 启动C++服务的脚本 |
| 71 | + └── python |
| 72 | + ├── deploy.sh # 预测部署bash脚本 |
| 73 | + ├── config_nlp.yml # Pipeline 的配置文件 |
| 74 | + ├── web_service.py # Pipeline 服务端的脚本 |
| 75 | + ├── rpc_client.py # Pipeline RPC客户端的脚本 |
| 76 | + └── predict.py # python 预测部署示例 |
| 77 | +|—— scripts |
| 78 | + ├── export_model.sh # 动态图参数导出静态图参数的bash文件 |
| 79 | + ├── export_to_serving.sh # 导出 Paddle Serving 模型格式的bash文件 |
| 80 | + ├── train_ce.sh # 匹配模型训练的bash文件 |
| 81 | + ├── evaluate_ce.sh # 评估验证文件bash脚本 |
| 82 | + ├── predict_ce.sh # 匹配模型预测脚本的bash文件 |
| 83 | +├── export_model.py # 动态图参数导出静态图参数脚本 |
| 84 | +├── export_to_serving.py # 导出 Paddle Serving 模型格式的脚本 |
| 85 | +├── data.py # 训练样本的转换逻辑 |
| 86 | +├── train_ce.py # 模型训练脚本 |
| 87 | +├── evaluate.py # 评估验证文件 |
| 88 | +├── predict.py # Pair-wise 模型预测脚本,输出文本对是相似度 |
| 89 | +
|
| 90 | +``` |
| 91 | + |
| 92 | +<a name="数据准备"></a> |
| 93 | + |
| 94 | +## 4. 数据准备 |
| 95 | + |
| 96 | +### 数据集说明 |
| 97 | + |
| 98 | +样例数据如下: |
| 99 | +``` |
| 100 | +(小学数学教材比较) 关键词:新加坡 新加坡与中国数学教材的特色比较数学教材,教材比较,问题解决 0 |
| 101 | +徐慧新疆肿瘤医院 头颈部非霍奇金淋巴瘤扩散加权成像ADC值与Ki-67表达相关性分析淋巴瘤,非霍奇金,头颈部肿瘤,磁共振成像 1 |
| 102 | +抗生素关性腹泻 鼠李糖乳杆菌GG防治消化系统疾病的研究进展鼠李糖乳杆菌,腹泻,功能性胃肠病,肝脏疾病,幽门螺杆菌 0 |
| 103 | +德州市图书馆 图书馆智慧化建设与融合创新服务研究图书馆;智慧化;阅读服务;融合创新 1 |
| 104 | +维生素c 综述 维生素C防治2型糖尿病研究进展维生素C;2型糖尿病;氧化应激;自由基;抗氧化剂 0 |
| 105 | +(白藜芦醇) 关键词:2型糖尿病 2型糖尿病大鼠心肌缺血再灌注损伤转录因子E2相关因子2/血红素氧合酶1信号通路的表达及白藜芦醇的干预研究糖尿病,2型,心肌缺血,再灌注损伤,白藜芦醇 1 |
| 106 | +融资偏好 创新型企业产业风险、融资偏好与融资选择融资偏好;产业风险;融资选择 1 |
| 107 | +星载激光雷达 星载激光雷达望远镜主镜超轻量化结构设计超轻量化;拓扑优化;集成优化;RMS;有限元仿真 1 |
| 108 | +``` |
| 109 | + |
| 110 | + |
| 111 | +### 数据集下载 |
| 112 | + |
| 113 | + |
| 114 | +- [literature_search_rank](https://paddlenlp.bj.bcebos.com/applications/literature_search_rank.zip) |
| 115 | + |
| 116 | +``` |
| 117 | +├── data # 排序数据集 |
| 118 | + ├── test.csv # 测试集 |
| 119 | + ├── dev_pairwise.csv # 验证集 |
| 120 | + └── train.csv # 训练集 |
| 121 | +``` |
| 122 | + |
| 123 | +<a name="模型训练"></a> |
| 124 | + |
| 125 | +## 5. 模型训练 |
| 126 | + |
| 127 | +**排序模型下载链接:** |
| 128 | + |
| 129 | + |
| 130 | +|Model|训练参数配置|硬件|MD5| |
| 131 | +| ------------ | ------------ | ------------ |-----------| |
| 132 | +|[ERNIE-Gram-Sort](https://bj.bcebos.com/v1/paddlenlp/models/ernie_gram_sort.zip)|<div style="width: 150pt">epoch:3 lr:5E-5 bs:64 max_len:64 </div>|<div style="width: 100pt">4卡 v100-16g</div>|d24ece68b7c3626ce6a24baa58dd297d| |
| 133 | + |
| 134 | + |
| 135 | +### 训练环境说明 |
| 136 | + |
| 137 | + |
| 138 | +- NVIDIA Driver Version: 440.64.00 |
| 139 | +- Ubuntu 16.04.6 LTS (Docker) |
| 140 | +- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz |
| 141 | + |
| 142 | + |
| 143 | +### 单机单卡训练/单机多卡训练 |
| 144 | + |
| 145 | +这里采用单机多卡方式进行训练,通过如下命令,指定 GPU 0,1,2,3 卡。如果采用单机单卡训练,只需要把`--gpu`参数设置成单卡的卡号即可 |
| 146 | + |
| 147 | +训练的命令如下: |
| 148 | + |
| 149 | +``` |
| 150 | +unset CUDA_VISIBLE_DEVICES |
| 151 | +python -u -m paddle.distributed.launch --gpus "0,1,2,3" --log_dir="logs" train_ce.py \ |
| 152 | + --device gpu \ |
| 153 | + --train_set data/train.csv \ |
| 154 | + --test_file data/dev_pairwise.csv \ |
| 155 | + --save_dir ./checkpoints \ |
| 156 | + --model_name_or_path rocketqa-base-cross-encoder \ |
| 157 | + --batch_size 32 \ |
| 158 | + --save_steps 10000 \ |
| 159 | + --max_seq_len 384 \ |
| 160 | + --learning_rate 1E-5 \ |
| 161 | + --weight_decay 0.01 \ |
| 162 | + --warmup_proportion 0.0 \ |
| 163 | + --logging_steps 10 \ |
| 164 | + --seed 1 \ |
| 165 | + --epochs 3 \ |
| 166 | + --eval_step 1000 |
| 167 | +``` |
| 168 | +也可以运行bash脚本: |
| 169 | + |
| 170 | +``` |
| 171 | +sh scripts/train_ce.sh |
| 172 | +``` |
| 173 | + |
| 174 | +<a name="评估"></a> |
| 175 | + |
| 176 | +## 6. 评估 |
| 177 | + |
| 178 | + |
| 179 | +``` |
| 180 | +python evaluate.py --model_name_or_path rocketqa-base-cross-encoder \ |
| 181 | + --init_from_ckpt checkpoints/model_80000/model_state.pdparams \ |
| 182 | + --test_file data/dev_pairwise.csv |
| 183 | +``` |
| 184 | +也可以运行bash脚本: |
| 185 | + |
| 186 | +``` |
| 187 | +sh scripts/evaluate_ce.sh |
| 188 | +``` |
| 189 | + |
| 190 | + |
| 191 | +成功运行后会输出下面的指标: |
| 192 | + |
| 193 | +``` |
| 194 | +eval_dev auc:0.829 |
| 195 | +``` |
| 196 | + |
| 197 | +<a name="预测"></a> |
| 198 | + |
| 199 | +## 7. 预测 |
| 200 | + |
| 201 | +### 准备预测数据 |
| 202 | + |
| 203 | +待预测数据为 tab 分隔的 tsv 文件,每一行为 1 个文本 Pair,和文本pair的语义索引相似度,(该相似度由召回模型算出,仅供参考),部分示例如下: |
| 204 | + |
| 205 | +``` |
| 206 | +中西方语言与文化的差异 第二语言习得的一大障碍就是文化差异。 0.5160342454910278 |
| 207 | +中西方语言与文化的差异 跨文化视角下中国文化对外传播路径琐谈跨文化,中国文化,传播,翻译 0.5145505666732788 |
| 208 | +中西方语言与文化的差异 从中西方民族文化心理的差异看英汉翻译语言,文化,民族文化心理,思维方式,翻译 0.5141439437866211 |
| 209 | +中西方语言与文化的差异 中英文化差异对翻译的影响中英文化,差异,翻译的影响 0.5138794183731079 |
| 210 | +中西方语言与文化的差异 浅谈文化与语言习得文化,语言,文化与语言的关系,文化与语言习得意识,跨文化交际 0.5131710171699524 |
| 211 | +``` |
| 212 | + |
| 213 | + |
| 214 | + |
| 215 | +### 开始预测 |
| 216 | + |
| 217 | +以上述 demo 数据为例,运行如下命令基于我们开源的rocketqa模型开始计算文本 Pair 的语义相似度: |
| 218 | + |
| 219 | +```shell |
| 220 | +unset CUDA_VISIBLE_DEVICES |
| 221 | +python predict.py \ |
| 222 | + --device 'gpu' \ |
| 223 | + --params_path checkpoints/model_80000/model_state.pdparams \ |
| 224 | + --model_name_or_path rocketqa-base-cross-encoder \ |
| 225 | + --test_set data/test.csv \ |
| 226 | + --topk 10 \ |
| 227 | + --batch_size 128 \ |
| 228 | + --max_seq_length 384 |
| 229 | +``` |
| 230 | +也可以直接执行下面的命令: |
| 231 | + |
| 232 | +``` |
| 233 | +sh scripts/predict_ce.sh |
| 234 | +``` |
| 235 | +得到下面的输出,分别是query,title和对应的预测概率: |
| 236 | + |
| 237 | +``` |
| 238 | +{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '高校\\十四五\\规划中学科建设要处理好五对关系\\十四五\\规划,学科建设,科技创新,人才培养', 'pred_prob': 0.7076062} |
| 239 | +{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '校企科研合作项目管理模式创新校企科研合作项目,管理模式,问题,创新', 'pred_prob': 0.64633846} |
| 240 | +{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '科研项目管理策略科研项目,项目管理,实施,必要性,策略', 'pred_prob': 0.63166416} |
| 241 | +{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '高校科研项目经费管理流程优化研究——以z大学为例高校,科研项目经费\\全流程\\管理,流程优化', 'pred_prob': 0.60351866} |
| 242 | +{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '关于推进我院科研发展进程的相关问题研究医院科研,主体,环境,信息化', 'pred_prob': 0.5688347} |
| 243 | +{'text_a': '加强科研项目管理有效促进医学科研工作', 'text_b': '医学临床科研选题原则和方法医学临床,科学研究,选题', 'pred_prob': 0.55190295} |
| 244 | +``` |
| 245 | + |
| 246 | +<a name="部署"></a> |
| 247 | + |
| 248 | +## 8. 部署 |
| 249 | + |
| 250 | +### 动转静导出 |
| 251 | + |
| 252 | +首先把动态图模型转换为静态图: |
| 253 | + |
| 254 | +``` |
| 255 | +python export_model.py \ |
| 256 | + --params_path checkpoints/model_80000/model_state.pdparams \ |
| 257 | + --model_name_or_path rocketqa-base-cross-encoder \ |
| 258 | + --output_path=./output |
| 259 | +``` |
| 260 | +也可以运行下面的bash脚本: |
| 261 | + |
| 262 | +``` |
| 263 | +sh scripts/export_model.sh |
| 264 | +``` |
| 265 | + |
| 266 | +### Paddle Inference |
| 267 | + |
| 268 | +使用PaddleInference |
| 269 | + |
| 270 | +``` |
| 271 | +python deploy/python/predict.py --model_dir ./output \ |
| 272 | + --input_file data/test.csv \ |
| 273 | + --model_name_or_path rocketqa-base-cross-encoder |
| 274 | +``` |
| 275 | +也可以运行下面的bash脚本: |
| 276 | + |
| 277 | +``` |
| 278 | +sh deploy/python/deploy.sh |
| 279 | +``` |
| 280 | +得到下面的输出,输出的是样本的query,title以及对应的概率: |
| 281 | + |
| 282 | +``` |
| 283 | +Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '科研项目管理策略科研项目,项目管理,实施,必要性,策略'} prob: 0.5479063987731934 |
| 284 | +Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '关于推进我院科研发展进程的相关问题研究医院科研,主体,环境,信息化'} prob: 0.5151925086975098 |
| 285 | +Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '深圳科技计划对高校科研项目资助现状分析与思考基础研究,高校,科技计划,科技创新'} prob: 0.42983829975128174 |
| 286 | +Data: {'query': '加强科研项目管理有效促进医学科研工作', 'title': '普通高校科研管理模式的优化与创新普通高校,科研,科研管理'} prob: 0.465454638004303 |
| 287 | +``` |
| 288 | + |
| 289 | +### Paddle Serving部署 |
| 290 | + |
| 291 | +Paddle Serving 的详细文档请参考 [Pipeline_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Python_Pipeline/Pipeline_Design_CN.md)和[Serving_Design](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Serving_Design_CN.md),首先把静态图模型转换成Serving的格式: |
| 292 | + |
| 293 | +``` |
| 294 | +python export_to_serving.py \ |
| 295 | + --dirname "output" \ |
| 296 | + --model_filename "inference.pdmodel" \ |
| 297 | + --params_filename "inference.pdiparams" \ |
| 298 | + --server_path "serving_server" \ |
| 299 | + --client_path "serving_client" \ |
| 300 | + --fetch_alias_names "predict" |
| 301 | +
|
| 302 | +``` |
| 303 | + |
| 304 | +参数含义说明 |
| 305 | +* `dirname`: 需要转换的模型文件存储路径,Program 结构文件和参数文件均保存在此目录。 |
| 306 | +* `model_filename`: 存储需要转换的模型 Inference Program 结构的文件名称。如果设置为 None ,则使用 `__model__` 作为默认的文件名 |
| 307 | +* `params_filename`: 存储需要转换的模型所有参数的文件名称。当且仅当所有模型参数被保>存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为 None |
| 308 | +* `server_path`: 转换后的模型文件和配置文件的存储路径。默认值为 serving_server |
| 309 | +* `client_path`: 转换后的客户端配置文件存储路径。默认值为 serving_client |
| 310 | +* `fetch_alias_names`: 模型输出的别名设置,比如输入的 input_ids 等,都可以指定成其他名字,默认不指定 |
| 311 | +* `feed_alias_names`: 模型输入的别名设置,比如输出 pooled_out 等,都可以重新指定成其他模型,默认不指定 |
| 312 | + |
| 313 | +也可以运行下面的 bash 脚本: |
| 314 | +``` |
| 315 | +sh scripts/export_to_serving.sh |
| 316 | +``` |
| 317 | +Paddle Serving的部署有两种方式,第一种方式是Pipeline的方式,第二种是C++的方式,下面分别介绍这两种方式的用法: |
| 318 | + |
| 319 | +#### Pipeline方式 |
| 320 | + |
| 321 | +修改对应预训练模型的`Tokenizer`: |
| 322 | + |
| 323 | +``` |
| 324 | +self.tokenizer = AutoTokenizer.from_pretrained('rocketqa-base-cross-encoder') |
| 325 | +``` |
| 326 | + |
| 327 | +启动 Pipeline Server: |
| 328 | + |
| 329 | +``` |
| 330 | +python web_service.py |
| 331 | +``` |
| 332 | + |
| 333 | +启动客户端调用 Server。 |
| 334 | + |
| 335 | +首先修改rpc_client.py中需要预测的样本: |
| 336 | + |
| 337 | +``` |
| 338 | +list_data = [{"query":"加强科研项目管理有效促进医学科研工作","title":"科研项目管理策略科研项目,项目管理,实施,必要性,策略"}]` |
| 339 | +``` |
| 340 | +然后运行: |
| 341 | +``` |
| 342 | +python rpc_client.py |
| 343 | +``` |
| 344 | +模型的输出为: |
| 345 | + |
| 346 | +``` |
| 347 | +PipelineClient::predict pack_data time:1662354188.422532 |
| 348 | +PipelineClient::predict before time:1662354188.423034 |
| 349 | +time to cost :0.016808509826660156 seconds |
| 350 | +(1,) |
| 351 | +[0.5479064] |
| 352 | +``` |
| 353 | +可以看到客户端发送了1条文本,这条文本的相似的概率值。 |
| 354 | + |
| 355 | +#### C++的方式 |
| 356 | + |
| 357 | +启动C++的Serving: |
| 358 | + |
| 359 | +``` |
| 360 | +python -m paddle_serving_server.serve --model serving_server --port 8600 --gpu_id 0 --thread 5 --ir_optim True |
| 361 | +``` |
| 362 | +也可以使用脚本: |
| 363 | + |
| 364 | +``` |
| 365 | +sh deploy/cpp/start_server.sh |
| 366 | +``` |
| 367 | +Client 可以使用 http 或者 rpc 两种方式,rpc 的方式为: |
| 368 | + |
| 369 | +``` |
| 370 | +python deploy/cpp/rpc_client.py |
| 371 | +``` |
| 372 | +运行的输出为: |
| 373 | + |
| 374 | +``` |
| 375 | +I0905 05:38:28.876770 28507 general_model.cpp:490] [client]logid=0,client_cost=158.124ms,server_cost=156.385ms. |
| 376 | +time to cost :0.15848731994628906 seconds |
| 377 | +[0.54790646] |
| 378 | +``` |
| 379 | +可以看到服务端返回了相似度结果 |
| 380 | + |
| 381 | +或者使用 http 的客户端访问模式: |
| 382 | + |
| 383 | +``` |
| 384 | +python deploy/cpp/http_client.py |
| 385 | +``` |
| 386 | +运行的输出为: |
| 387 | +``` |
| 388 | +time to cost :0.13054680824279785 seconds |
| 389 | +0.5479064707850817 |
| 390 | +``` |
| 391 | +可以看到服务端返回了相似度结果 |
| 392 | + |
| 393 | + |
| 394 | +## Reference |
| 395 | + |
| 396 | +[1] Xiao, Dongling, Yu-Kun Li, Han Zhang, Yu Sun, Hao Tian, Hua Wu, and Haifeng Wang. “ERNIE-Gram: Pre-Training with Explicitly N-Gram Masked Language Modeling for Natural Language Understanding.” ArXiv:2010.12148 [Cs]. |
| 397 | + |
| 398 | +[2] Yingqi Qu, Yuchen Ding, Jing Liu, Kai Liu, Ruiyang Ren, Wayne Xin Zhao, Daxiang Dong, Hua Wu, Haifeng Wang: |
| 399 | +RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering. NAACL-HLT 2021: 5835-5847 |
0 commit comments