|
12 | 12 | * [Task 4: Regression](#task-4) |
13 | 13 | * [Task 5: Sentiment Analysis](#task-5) |
14 | 14 | * [Task 6: Question Paraphrase](#task-6) |
| 15 | + * [Task 7: Knowledge Distillation for Model Compression](#task-7) |
| 16 | + 1. [Compression for Query Binary Classifier](#task-7.1) |
| 17 | + 2. [Compression for Text Matching Model](#task-7.2) |
| 18 | + 3. [Compression for Slot Filling Model](#task-7.3) |
| 19 | + 4. [Compression for MRC Model](#task-7.4) |
15 | 20 | * [Advanced Usage](#advanced-usage) |
16 | 21 | * [Extra Feature Support](#extra-feature) |
17 | 22 | * [Learning Rate Decay](#lr-decay) |
@@ -382,6 +387,61 @@ This task is to determine whether a pair of questions are semantically equivalen |
382 | 387 | |
383 | 388 | *Tips: the model file and train log file can be found in JSON config file's outputs/save_base_dir.* |
384 | 389 |
|
| 390 | +### <span id="task-7">Task 7: Knowledge Distillation for Model Compression</span> |
| 391 | + |
| 392 | +Knowledge Distillation is a common method for compressing models in order to improve inference speed. Here are some reference papers (a minimal loss sketch follows the list):
| 393 | +- [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531) |
| 394 | +- [Model Compression with Multi-Task Knowledge Distillation for Web-scale Question Answering System](https://arxiv.org/abs/1904.09636) |
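For reference, the soft-target loss from the first paper combines a temperature-scaled KL term with the usual cross-entropy on hard labels. A minimal PyTorch sketch (function and argument names are illustrative, not NeuronBlocks APIs):

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: KL divergence between temperature-softened
    # distributions, scaled by T^2 to keep gradient magnitudes comparable.
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: standard cross-entropy on the gold labels.
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1.0 - alpha) * ce
```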
| 395 | + |
| 396 | +#### <span id="task-7.1">7.1: Compression for Query Binary Classifier</span> |
| 397 | +This task trains a lightweight query regression model (the student) to learn from a heavy teacher model, such as a BERT-based query classifier. Training minimizes the score difference between the student model's output and the teacher model's output.
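Since the soft label here is a single score, a conceptual training step reduces to regression onto the teacher's output. A minimal sketch, assuming a PyTorch student model that maps query features to a score (`student`, the optimizer, and the batch fields are hypothetical stand-ins, not NeuronBlocks APIs):

```python
import torch.nn.functional as F

def distill_step(student, optimizer, query_features, teacher_scores):
    # Regress the student's predicted score onto the teacher's soft
    # label, minimizing their difference with an MSE loss.
    optimizer.zero_grad()
    student_scores = student(query_features).squeeze(-1)
    loss = F.mse_loss(student_scores, teacher_scores)
    loss.backward()
    optimizer.step()
    return loss.item()
```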
| 398 | +- ***Dataset*** |
| 399 | +*PROJECT_ROOT/dataset/knowledge_distillation/query_binary_classifier*: |
| 400 | + * *train.tsv* and *valid.tsv*: two columns, namely **Query** and **Score**. |
| 401 | + **Score** is the output score of a heavy teacher model (a fine-tuned BERT-base model), which serves as the soft label the student model learns as knowledge.
| 402 | + * *test.tsv*: two columns, namely **Query** and **Label**. |
| 403 | + **Label** is a binary value, where 0 means negative and 1 means positive.
| 404 | + |
| 405 | + You can also substitute your own dataset for the compression training (illustrative rows are shown below).
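  For illustration, hypothetical *train.tsv* rows might look like this (tab-separated; the queries and scores are made up):

  ```
  how do i reset my password	0.9731
  weather in seattle tomorrow	0.0412
  ```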
| 406 | + |
| 407 | +- ***Usage*** |
| 408 | + |
| 409 | + 1. Train student model |
| 410 | + ```bash |
| 411 | + cd PROJECT_ROOT |
| 412 | + python train.py --conf_path=model_zoo/nlp_tasks/knowledge_distillation/conf_kdqbc_bilstmattn_cnn.json |
| 413 | + ``` |
| 414 | + |
| 415 | + 2. Test student model |
| 416 | + ```bash |
| 417 | + cd PROJECT_ROOT |
| 418 | + python test.py --conf_path=models/kdqbc_bilstmattn_cnn/train/conf_kdqbc_bilstmattn_cnn.json --previous_model_path models/kdqbc_bilstmattn_cnn/train/model.nb --predict_output_path models/kdqbc_bilstmattn_cnn/test/test.tsv --test_data_path dataset/knowledge_distillation/query_binary_classifier/test.tsv |
| 419 | + ``` |
| 420 | + |
| 421 | + 3. Calculate AUC metric |
| 422 | + ```bash |
| 423 | + cd PROJECT_ROOT |
| 424 | + python tools/AUC.py --input_file models/kdqbc_bilstmattn_cnn/test/test.tsv --predict_index 2 --label_index 1 |
| 425 | + ``` |
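For reference, the AUC computation boils down to reading the label and prediction columns and calling a standard AUC routine. A minimal sketch, assuming the column layout implied by the flags above (label in column 1, prediction in column 2 of the test.tsv written by test.py):

```python
import csv
from sklearn.metrics import roc_auc_score

labels, scores = [], []
with open("models/kdqbc_bilstmattn_cnn/test/test.tsv", encoding="utf-8") as f:
    for row in csv.reader(f, delimiter="\t"):
        labels.append(int(row[1]))    # gold binary label
        scores.append(float(row[2]))  # student model's predicted score
print("AUC: %.4f" % roc_auc_score(labels, scores))
```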
| 426 | + |
| 427 | + *Tips: you can try different models by running different JSON config files.* |
| 428 | + |
| 429 | +- ***Result*** |
| 430 | + |
| 431 | + The AUC of the student model is very close to that of the teacher model, while its inference speed is 3.5X~4X faster.
| 432 | + |
| 433 | + |Model|AUC| |
| 434 | + |-----|---| |
| 435 | + |Teacher|0.9112| |
| 436 | + |Student-BiLSTMAttn+TextCNN (NeuronBlocks)|0.8941| |
| 437 | + |
| 438 | + *Tips: the model file and train log file can be found in JSON config file's outputs/save_base_dir.* |
| 439 | +
|
| 440 | +#### <span id="task-7.2">7.2: Compression for Text Matching Model (ongoing)</span> |
| 441 | +#### <span id="task-7.3">7.3: Compression for Slot Filling Model (ongoing)</span> |
| 442 | +#### <span id="task-7.4">7.4: Compression for MRC Model (ongoing)</span> |
| 443 | +
|
| 444 | +
|
385 | 445 | ## <span id="advanced-usage">Advanced Usage</span> |
386 | 446 |
|
387 | 447 | After building a model, the next goal is to train a model with good performance. It depends on a highly expressive model and tricks of the model training. NeuronBlocks provides some tricks of model training. |
|