Skip to content

Commit 7932dd2

Browse files
1649759610tianxin
andauthored
add DiffCSE model (#2643)
* initial commit * refine readme * refine codestyle * refine readme * refine readme * fix model saving bug * initial commit * initial commit * initial commit * use common metric instead of eval_metrics.py and remove unuseful code Co-authored-by: tianxin <[email protected]>
1 parent 84e8026 commit 7932dd2

File tree

9 files changed

+2322
-0
lines changed

9 files changed

+2322
-0
lines changed
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# 无监督语义匹配模型 [DiffCSE](https://arxiv.org/pdf/2204.10298.pdf)
2+
3+
借鉴 [DiffCSE](https://arxiv.org/pdf/2204.10298.pdf) 的思路,实现了 DiffCSE 模型。相比于 SimCSE 模型,DiffCSE模型会更关注语句之间的差异性,具有精确的向量表示能力。DiffCSE 模型同样适合缺乏监督数据,但是又有大量无监督数据的匹配和检索场景。
4+
5+
## 快速开始
6+
### 代码结构说明
7+
8+
以下是本项目主要代码结构及说明:
9+
10+
```
11+
DiffCSE/
12+
├── model.py # DiffCSE 模型组网代码
13+
├── custom_ernie.py # 为适配 DiffCSE 模型,对ERNIE模型进行了部分修改
14+
├── data.py # 无监督语义匹配训练数据、测试数据的读取逻辑
15+
├── run_diffcse.py # 模型训练、评估、预测的主脚本
16+
├── utils.py # 包括一些常用的工具式函数
17+
├── run_train.sh # 模型训练的脚本
18+
├── run_eval.sh # 模型评估的脚本
19+
└── run_infer.sh # 模型预测的脚本
20+
```
21+
22+
### 模型训练
23+
默认使用无监督模式进行训练 DiffCSE,模型训练数据的数据样例如下所示,每行表示一条训练样本:
24+
```shell
25+
全年地方财政总收入3686.81亿元,比上年增长12.3%。
26+
“我对案情并不十分清楚,所以没办法提出批评,建议,只能希望通过质询,要求检察院对此做出说明。”他说。
27+
据调查结果显示:2015年微商行业总体市场规模达到1819.5亿元,预计2016年将达到3607.3亿元,增长率为98.3%。
28+
前往冈仁波齐需要办理目的地包含日喀则和阿里地区的边防证,外转沿途有一些补给点,可购买到干粮和饮料。
29+
```
30+
31+
可以运行如下命令,开始模型训练并且进行模型测试。
32+
33+
```shell
34+
gpu_ids=0
35+
export CUDA_VISIBLE_DEVICES=${gpu_ids}
36+
37+
log_dir="log_train"
38+
python -u -m paddle.distributed.launch --gpus ${gpu_ids} --log_dir ${log_dir} \
39+
run_diffcse.py \
40+
--mode "train" \
41+
--encoder_name "rocketqa-zh-dureader-query-encoder" \
42+
--generator_name "ernie-3.0-base-zh" \
43+
--discriminator_name "ernie-3.0-base-zh" \
44+
--max_seq_length "128" \
45+
--output_emb_size "32" \
46+
--train_set_file "your train_set path" \
47+
--eval_set_file "your dev_set path" \
48+
--save_dir "./checkpoints" \
49+
--log_dir ${log_dir} \
50+
--save_steps "50000" \
51+
--eval_steps "1000" \
52+
--epochs "3" \
53+
--batch_size "32" \
54+
--mlm_probability "0.15" \
55+
--lambda_weight "0.15" \
56+
--learning_rate "3e-5" \
57+
--weight_decay "0.01" \
58+
--warmup_proportion "0.01" \
59+
--seed "0" \
60+
--device "gpu"
61+
```
62+
63+
可支持配置的参数:
64+
* `mode`:可选,用于指明本次运行是模型训练、模型评估还是模型预测,仅支持[train, eval, infer]三种模式;默认为 infer。
65+
* `encoder_name`:可选,DiffCSE模型中用于向量抽取的模型名称;默认为 ernie-3.0-base-zh。
66+
* `generator_name`: 可选,DiffCSE模型中生成器的模型名称;默认为 ernie-3.0-base-zh。
67+
* `discriminator_name`: 可选,DiffCSE模型中判别器的模型名称;默认为 rocketqa-zh-dureader-query-encoder。
68+
* `max_seq_length`:可选,ERNIE-Gram 模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。
69+
* `output_emb_size`:可选,向量抽取模型输出向量的维度;默认为32。
70+
* `train_set_file`:可选,用于指定训练集的路径。
71+
* `eval_set_file`:可选,用于指定验证集的路径。
72+
* `save_dir`:可选,保存训练模型的目录;
73+
* `log_dir`:可选,训练训练过程中日志的输出目录;
74+
* `save_steps`:可选,用于指定模型训练过程中每隔多少 step 保存一次模型。
75+
* `eval_steps`:可选,用于指定模型训练过程中每隔多少 step,使用验证集评估一次模型。
76+
* `epochs`: 模型训练轮次,默认为3。
77+
* `batch_size`:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
78+
* `mlm_probability`:可选,利用生成器预测时,控制单词掩码的比例,默认为0.15。
79+
* `lambda_weight`:可选,控制RTD任务loss的占比,默认为0.15。
80+
* `learning_rate`:可选,Fine-tune 的最大学习率;默认为5e-5。
81+
* `weight_decay`:可选,控制正则项力度的参数,用于防止过拟合,默认为0.01。
82+
* `warmup_proportion`:可选,学习率 warmup 策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到 learning_rate, 而后再缓慢衰减,默认为0.01。
83+
* `seed`:可选,随机种子,默认为1000.
84+
* `device`: 选用什么设备进行训练,可选 cpu 或 gpu。如使用 gpu 训练则参数 gpus 指定GPU卡号。
85+
86+
程序运行时将会自动进行训练,评估。同时训练过程中会自动保存模型在指定的`save_dir`中。
87+
如:
88+
```text
89+
checkpoints/
90+
├── best
91+
│   ├── model_state.pdparams
92+
│   ├── tokenizer_config.json
93+
│   ├── special_tokens_map.json
94+
│   └── vocab.txt
95+
└── ...
96+
```
97+
98+
### 模型评估
99+
在模型评估时,需要使用带有标签的数据,以下展示了几条模型评估数据样例,每行表示一条训练样本,每行共计包含3列,分别是query1, query2, label:
100+
```shell
101+
右键单击此电脑选择属性,如下图所示 右键单击此电脑选择属性,如下图所示 5
102+
好医生解密||是什么,让美洲大蠊能美容还能救命 解密美洲大蠊巨大药用价值 1
103+
蒜香蜜汁烤鸡翅的做法 外香里嫩一口爆汁蒜蓉蜜汁烤鸡翅的做法 3
104+
项目计划书 篇2 简易项目计划书(参考模板) 2
105+
夏天幼儿园如何正确使用空调? 老师们该如何正确使用空调,让孩子少生病呢? 3
106+
```
107+
108+
109+
可以运行如下命令,进行模型评估。
110+
111+
```shell
112+
gpu_ids=0
113+
export CUDA_VISIBLE_DEVICES=${gpu_ids}
114+
115+
log_dir="log_eval"
116+
python -u -m paddle.distributed.launch --gpus ${gpu_ids} --log_dir ${log_dir} \
117+
run_diffcse.py \
118+
--mode "eval" \
119+
--encoder_name "rocketqa-zh-dureader-query-encoder" \
120+
--max_seq_length "128" \
121+
--output_emb_size "32" \
122+
--eval_set_file "your dev_set path" \
123+
--ckpt_dir "./checkpoints/best" \
124+
--batch_size "32" \
125+
--seed "0" \
126+
--device "gpu"
127+
```
128+
可支持配置的参数:
129+
* `ckpt_dir`: 用于指定进行模型评估的checkpoint路径。
130+
131+
其他参数解释同上。
132+
133+
### 基于动态图模型预测
134+
在模型预测时,需要给定待预测的两条文本,以下展示了几条模型预测的数据样例,每行表示一条训练样本,每行共计包含2列,分别是query1, query2:
135+
```shell
136+
韩国现代摩比斯2015招聘 韩国现代摩比斯2015校园招聘信息
137+
《DNF》封号减刑方法 被封一年怎么办? DNF封号减刑方法 封号一年怎么减刑
138+
原神手鞠游戏三个刷新位置一览 手鞠游戏三个刷新位置一览
139+
```
140+
141+
可以运行如下命令,进行模型预测:
142+
```shell
143+
gpu_ids=0
144+
export CUDA_VISIBLE_DEVICES=${gpu_ids}
145+
146+
log_dir="log_infer"
147+
python -u -m paddle.distributed.launch --gpus ${gpu_ids} --log_dir ${log_dir} \
148+
run_diffcse.py \
149+
--mode "infer" \
150+
--encoder_name "rocketqa-zh-dureader-query-encoder" \
151+
--max_seq_length "128" \
152+
--output_emb_size "32" \
153+
--infer_set_file "your test_set path \
154+
--ckpt_dir "./checkpoints/best" \
155+
--save_infer_path "./infer_result.txt" \
156+
--batch_size "32" \
157+
--seed "0" \
158+
--device "gpu"
159+
```
160+
161+
可支持配置的参数:
162+
* `infer_set_file`: 可选,用于指定测试集的路径。
163+
* `save_infer_path`: 可选,用于保存模型预测结果的文件路径。
164+
165+
其他参数解释同上。 待模型预测结束后,会将结果保存至save_infer_path参数指定的文件中。
166+
167+
168+
## Reference
169+
[1] Chuang Y S , Dangovski R , Luo H , et al. DiffCSE: Difference-based Contrastive Learning for Sentence Embeddings[J]. arXiv e-prints, 2022. https://arxiv.org/pdf/2204.10298.pdf.

0 commit comments

Comments
 (0)