Skip to content

Commit d39818c

Browse files
Feat/topic classifier (#1584)
1 parent c8264bf commit d39818c

File tree

3 files changed

+192
-14
lines changed

3 files changed

+192
-14
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
{
2+
"dataset_reader": {
3+
"class_name": "basic_classification_reader",
4+
"class_sep": ";",
5+
"x": "text",
6+
"y": "topic",
7+
"data_path": "{DOWNLOADS_PATH}/dp_topics_downsampled_data/",
8+
"train" : "train.csv",
9+
"valid" : "valid.csv"
10+
},
11+
"dataset_iterator": {
12+
"class_name": "basic_classification_iterator",
13+
"seed": 42
14+
},
15+
"chainer": {
16+
"in": [
17+
"x"
18+
],
19+
"in_y": [
20+
"y"
21+
],
22+
"pipe": [
23+
{
24+
"class_name": "torch_transformers_preprocessor",
25+
"vocab_file": "{TRANSFORMER}",
26+
"do_lower_case": true,
27+
"max_seq_length": 128,
28+
"in": [
29+
"x"
30+
],
31+
"out": [
32+
"bert_features"
33+
]
34+
},
35+
{
36+
"id": "classes_vocab",
37+
"class_name": "simple_vocab",
38+
"fit_on": [
39+
"y"
40+
],
41+
"save_path": "{MODEL_PATH}/classes.dict",
42+
"load_path": "{MODEL_PATH}/classes.dict",
43+
"in": [
44+
"y"
45+
],
46+
"out": [
47+
"y_ids"
48+
]
49+
},
50+
{
51+
"in": [
52+
"y_ids"
53+
],
54+
"out": [
55+
"y_onehot"
56+
],
57+
"class_name": "one_hotter",
58+
"id": "my_one_hotter",
59+
"depth": "#classes_vocab.len",
60+
"single_vector": true
61+
},
62+
{
63+
"class_name": "torch_transformers_classifier",
64+
"one_hot_labels": true,
65+
"n_classes": "#classes_vocab.len",
66+
"return_probas": true,
67+
"pretrained_bert": "{TRANSFORMER}",
68+
"save_path": "{MODEL_PATH}/model",
69+
"load_path": "{MODEL_PATH}/model",
70+
"multilabel": true,
71+
"optimizer": "AdamW",
72+
"optimizer_parameters": {
73+
"lr": 1e-05
74+
},
75+
"learning_rate_drop_patience": 5,
76+
"learning_rate_drop_div": 2.0,
77+
"in": [
78+
"bert_features"
79+
],
80+
"in_y": [
81+
"y_onehot"
82+
],
83+
"out": [
84+
"y_pred_probas"
85+
]
86+
},
87+
{
88+
"in": "y_pred_probas",
89+
"out": "y_pred_ids",
90+
"class_name": "proba2labels",
91+
"max_proba": false,
92+
"confidence_threshold": 0.5
93+
},
94+
{
95+
"in": "y_pred_ids",
96+
"out": "y_pred_labels",
97+
"ref": "classes_vocab"
98+
},
99+
{
100+
"ref": "my_one_hotter",
101+
"in": "y_pred_ids",
102+
"out": "y_pred_onehot"
103+
}
104+
],
105+
"out": [
106+
"y_pred_labels"
107+
]
108+
},
109+
"train": {
110+
"epochs": 100,
111+
"batch_size": 64,
112+
"metrics": [
113+
{
114+
"name": "f1_macro",
115+
"inputs": [
116+
"y_onehot",
117+
"y_pred_onehot"
118+
]
119+
},
120+
{
121+
"name": "f1_weighted",
122+
"inputs": [
123+
"y_onehot",
124+
"y_pred_onehot"
125+
]
126+
},
127+
{
128+
"name": "accuracy",
129+
"inputs": [
130+
"y",
131+
"y_pred_labels"
132+
]
133+
},
134+
{
135+
"name": "roc_auc",
136+
"inputs": [
137+
"y_onehot",
138+
"y_pred_probas"
139+
]
140+
}
141+
],
142+
"validation_patience": 10,
143+
"val_every_n_epochs": 1,
144+
"log_every_n_epochs": 1,
145+
"log_every_n_batches": 100,
146+
"show_examples": false,
147+
"evaluation_targets": [
148+
"train",
149+
"valid",
150+
"test"
151+
],
152+
"tensorboard_log_dir": "{MODEL_PATH}/logs",
153+
"class_name": "torch_trainer"
154+
},
155+
"metadata": {
156+
"variables": {
157+
"TRANSFORMER": "distilbert-base-uncased",
158+
"ROOT_PATH": "~/.deeppavlov",
159+
"DOWNLOADS_PATH": "{ROOT_PATH}/downloads",
160+
"MODELS_PATH": "{ROOT_PATH}/models",
161+
"MODEL_PATH": "{MODELS_PATH}/classifiers/topic_distilbert_base_v0"
162+
},
163+
"download": [
164+
{
165+
"url": "http://files.deeppavlov.ai/datasets/dp_topics_downsampled_dataset_v0.tar.gz",
166+
"subdir": "{DOWNLOADS_PATH}"
167+
},
168+
{
169+
"url": "http://files.deeppavlov.ai/deeppavlov_data/classifiers/topic_distilbert_base_v0.tar.gz",
170+
"subdir": "{MODELS_PATH}/classifiers"
171+
}
172+
]
173+
}
174+
}

docs/features/models/classifiers.rst

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,19 +234,21 @@ the floating point labels are converted to integer labels according to the inter
234234
corresponding to `very negative`, `negative`, `neutral`, `positive`, `very positive` classes.
235235

236236

237-
+------------------+--------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------+--------+-----------+
238-
| Task | Dataset | Lang | Model | Metric | Valid | Test | Downloads |
239-
+==================+====================+======+=================================================================================================+=============+========+========+===========+
240-
| Insult detection | `Insults`_ | En | :config:`English BERT <classifiers/insults_kaggle_bert.json>` | ROC-AUC | 0.9327 | 0.8602 | 1.1 Gb |
241-
+------------------+--------------------+ +-------------------------------------------------------------------------------------------------+-------------+--------+--------+-----------+
242-
| Sentiment |`SST`_ | | :config:`5-classes SST on conversational BERT <classifiers/sentiment_sst_conv_bert.json>` | Accuracy | 0.6293 | 0.6626 | 1.1 Gb |
243-
+------------------+--------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------+--------+-----------+
244-
| Sentiment |`Twitter mokoron`_ | Ru | :config:`RuWiki+Lenta emb w/o preprocessing <classifiers/sentiment_twitter.json>` | F1-macro | 0.9965 | 0.9961 | 6.2 Gb |
245-
+ +--------------------+ +-------------------------------------------------------------------------------------------------+-------------+--------+--------+-----------+
246-
| |`RuSentiment`_ | | :config:`Multi-language BERT <classifiers/rusentiment_bert.json>` | F1-weighted | 0.6787 | 0.7005 | 1.3 Gb |
247-
+ + + +-------------------------------------------------------------------------------------------------+ +--------+--------+-----------+
248-
| | | | :config:`Conversational RuBERT <classifiers/rusentiment_convers_bert.json>` | | 0.739 | 0.7724 | 1.5 Gb |
249-
+------------------+--------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------+--------+-----------+
237+
+------------------+----------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------------+--------------+-----------+
238+
| Task | Dataset | Lang | Model | Metric | Valid | Test | Downloads |
239+
+==================+======================+======+=================================================================================================+=============+==============+==============+===========+
240+
| Insult detection | `Insults`_ | En | :config:`English BERT <classifiers/insults_kaggle_bert.json>` | ROC-AUC | 0.9327 | 0.8602 | 1.1 Gb |
241+
+------------------+----------------------+ +-------------------------------------------------------------------------------------------------+-------------+--------------+--------------+-----------+
242+
| Sentiment |`SST`_ | | :config:`5-classes SST on conversational BERT <classifiers/sentiment_sst_conv_bert.json>` | Accuracy | 0.6293 | 0.6626 | 1.1 Gb |
243+
+------------------+----------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------------+--------------+-----------+
244+
| Sentiment |`Twitter mokoron`_ | Ru | :config:`RuWiki+Lenta emb w/o preprocessing <classifiers/sentiment_twitter.json>` | F1-macro | 0.9965 | 0.9961 | 6.2 Gb |
245+
+ +----------------------+ +-------------------------------------------------------------------------------------------------+-------------+--------------+--------------+-----------+
246+
| |`RuSentiment`_ | | :config:`Multilingual BERT <classifiers/rusentiment_bert.json>` | F1-weighted | 0.6787 | 0.7005 | 1.3 Gb |
247+
+ + + +-------------------------------------------------------------------------------------------------+ +--------------+--------------+-----------+
248+
| | | | :config:`Conversational RuBERT <classifiers/rusentiment_convers_bert.json>` | | 0.739 | 0.7724 | 1.5 Gb |
249+
+------------------+----------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------------+--------------+-----------+
250+
| Topics | `DeepPavlov Topics`_ | En | :config:`DistilBERT base uncased <classifiers/topics_distilbert_base_uncased.json>` | F1-w / F1-m | 0.877/0.830 | 0.878/0.831 | 0.7 Gb |
251+
+------------------+----------------------+------+-------------------------------------------------------------------------------------------------+-------------+--------------+--------------+-----------+
250252

251253
.. _`DSTC 2`: http://camdial.org/~mh521/dstc/
252254
.. _`Insults`: https://www.kaggle.com/c/detecting-insults-in-social-commentary
@@ -257,6 +259,7 @@ corresponding to `very negative`, `negative`, `neutral`, `positive`, `very posit
257259
.. _`Yahoo-L31`: https://webscope.sandbox.yahoo.com/catalog.php?datatype=l
258260
.. _`Yahoo-L6`: https://webscope.sandbox.yahoo.com/catalog.php?datatype=l
259261
.. _`SST`: https://nlp.stanford.edu/sentiment/index.html
262+
.. _`DeepPavlov Topics`: https://deeppavlov.ai/datasets/topics
260263

261264
GLUE Benchmark
262265
--------------

tests/test_quick_start.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@
118118
("classifiers/glue/glue_rte_roberta_mnli.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK],
119119
("classifiers/superglue/superglue_copa_roberta.json", "classifiers", ('TI',)): [LIST_ARGUMENTS_INFER_CHECK],
120120
("classifiers/superglue/superglue_boolq_roberta_mnli.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK],
121-
("classifiers/superglue/superglue_record_roberta.json", "classifiers", ('TI',)): [RECORD_ARGUMENTS_INFER_CHECK]
121+
("classifiers/superglue/superglue_record_roberta.json", "classifiers", ('TI',)): [RECORD_ARGUMENTS_INFER_CHECK],
122+
("classifiers/topics_distilbert_base_uncased.json", "classifiers", ('TI',)): [ONE_ARGUMENT_INFER_CHECK]
122123
},
123124
"distil": {
124125
("classifiers/paraphraser_convers_distilrubert_2L.json", "distil", ('IP')): [TWO_ARGUMENTS_INFER_CHECK],

0 commit comments

Comments
 (0)