Skip to content

Commit 54df619

Browse files
authored
Add Retrieval based multi class classification (#3180)
* Add Retrieval based multi class classification * Update README.md and dtype * Update README.md
1 parent 6f953a9 commit 54df619

28 files changed

+2698
-16
lines changed

applications/text_classification/hierarchical/retrieval_based/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
# 1.基于语义索引的分类任务介绍
1717

18-
以前的分类任务中,标签信息作为无实际意义,独立存在的one-hot编码形式存在,这种做法会潜在的丢失标签的语义信息,本方案把文本分类任务中的标签信息转换成含有语义信息的语义向量,将文本分类任务转换成向量检索和匹配的任务。这样做的好处是对于一些类别标签不是很固定的场景,或者需要经常有一些新增类别的需求的情况非常合适。另外,对于一些新的相关的分类任务,这种方法也不需要模型重新学习或者设计一种新的模型结构来适应新的任务。总的来说,这种基于检索的文本分类方法能够有很好的拓展性,能够利用标签里面包含的语义信息,不需要重新进行学习。
18+
以前的分类任务中,标签信息作为无实际意义,独立存在的one-hot编码形式存在,这种做法会潜在的丢失标签的语义信息,本方案把文本分类任务中的标签信息转换成含有语义信息的语义向量,将文本分类任务转换成向量检索和匹配的任务。这样做的好处是对于一些类别标签不是很固定的场景,或者需要经常有一些新增类别的需求的情况非常合适。另外,对于一些新的相关的分类任务,这种方法也不需要模型重新学习或者设计一种新的模型结构来适应新的任务。总的来说,这种基于检索的文本分类方法能够有很好的拓展性,能够利用标签里面包含的语义信息,不需要重新进行学习。这种方法可以应用到相似标签推荐,文本标签标注,金融风险事件分类,政务信访分类等领域。
1919

20-
本方案是基于语义索引模型的分类,语义索引模型的目标是:给定输入文本,模型可以从海量候选召回库中**快速、准确**地召回一批语义相关文本。如果召回的文本带有类别标签,则可以把召回文本的类别标签作为给定输入文本的类别。本方案使用双塔模型,训练阶段引入In-batch Negatives 策略,使用hnswlib建立索引库,进行召回测试。最后利用召回的结果使用 Accuracy 指标来评估语义索引模型的分类的效果。
20+
本方案是基于语义索引模型的分类,语义索引模型的目标是:给定输入文本,模型可以从海量候选召回库中**快速、准确**地召回一批语义相关文本。基于语义索引的分类方法有两种,第一种方法是直接把标签变成召回库,即把输入文本和标签的文本进行匹配,第二种是利用召回的文本带有类别标签,把召回文本的类别标签作为给定输入文本的类别。本方案使用双塔模型,训练阶段引入In-batch Negatives 策略,使用hnswlib建立索引库,并把标签作为召回库,进行召回测试。最后利用召回的结果使用 Accuracy 指标来评估语义索引模型的分类的效果。
2121

2222

2323
**效果评估**
@@ -351,7 +351,7 @@ CUDA_VISIBLE_DEVICES=0 python utils/feature_extract.py \
351351

352352
```
353353
python utils/vector_insert.py \
354-
--vector_path ./data/corpus_embedding.npy
354+
--vector_path ./data/label_embedding.npy
355355
```
356356
也可以直接运行:
357357

applications/text_classification/hierarchical/retrieval_based/deploy/python/predict.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,9 @@ def extract_embedding(self, data, tokenizer):
200200
examples.append((input_ids, segment_ids))
201201

202202
batchify_fn = lambda samples, fn=Tuple(
203-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
204-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
203+
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input
204+
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
205+
), # segment
205206
): fn(samples)
206207

207208
input_ids, segment_ids = batchify_fn(examples)
@@ -233,10 +234,12 @@ def predict(self, data, tokenizer):
233234
(input_ids, segment_ids, title_ids, title_segment_ids))
234235

235236
batchify_fn = lambda samples, fn=Tuple(
236-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
237-
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
238-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
239-
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
237+
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input
238+
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
239+
), # segment
240+
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input
241+
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
242+
), # segment
240243
): fn(samples)
241244

242245
query_ids, query_segment_ids, title_ids, title_segment_ids = batchify_fn(

applications/text_classification/hierarchical/retrieval_based/deploy/python/web_service.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ def preprocess(self, input_dicts, data_id, log_id):
5151
input_ids, segment_ids = convert_example([example], self.tokenizer)
5252
examples.append((input_ids, segment_ids))
5353
batchify_fn = lambda samples, fn=Tuple(
54-
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # input
55-
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # segment
54+
Pad(axis=0, pad_val=self.tokenizer.pad_token_id, dtype="int64"
55+
), # input
56+
Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id, dtype="int64"
57+
), # segment
5658
): fn(samples)
5759
input_ids, segment_ids = batchify_fn(examples)
5860
feed_dict = {}

applications/text_classification/hierarchical/retrieval_based/predict.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,14 @@ def predict(model, data_loader):
8484
max_seq_length=args.max_seq_length,
8585
pad_to_max_seq_len=args.pad_to_max_seq_len)
8686
batchify_fn = lambda samples, fn=Tuple(
87-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input
88-
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # query_segment
89-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # title_input
90-
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # tilte_segment
87+
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
88+
), # query_input
89+
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
90+
), # query_segment
91+
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
92+
), # title_input
93+
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
94+
), # tilte_segment
9195
): [data for data in fn(samples)]
9296
valid_ds = load_dataset(read_text_pair,
9397
data_path=args.text_pair_file,

applications/text_classification/hierarchical/retrieval_based/scripts/train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# GPU training
22
root_path=inbatch
33
data_path=data
4-
python -u -m paddle.distributed.launch --gpus "1,2,3,4" \
4+
python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
55
train.py \
66
--device gpu \
77
--save_dir ./checkpoints/${root_path} \

0 commit comments

Comments
 (0)