
Commit b68048f

Add Knowledge distillation to Tutorial.md (#4)
update Tutorial.md with knowledge distillation tasks
1 parent eaef926 commit b68048f


Tutorial.md

Lines changed: 60 additions & 0 deletions
@@ -12,6 +12,11 @@
* [Task 4: Regression](#task-4)
* [Task 5: Sentiment Analysis](#task-5)
* [Task 6: Question Paraphrase](#task-6)
* [Task 7: Knowledge Distillation for Model Compression](#task-7)
1. [Compression for Query Binary Classifier](#task-7.1)
2. [Compression for Text Matching Model](#task-7.2)
3. [Compression for Slot Filling Model](#task-7.3)
4. [Compression for MRC Model](#task-7.4)
* [Advanced Usage](#advanced-usage)
* [Extra Feature Support](#extra-feature)
* [Learning Rate Decay](#lr-decay)
@@ -382,6 +387,61 @@ This task is to determine whether a pair of questions are semantically equivalent
*Tips: the model file and train log file can be found in JSON config file's outputs/save_base_dir.*

### <span id="task-7">Task 7: Knowledge Distillation for Model Compression</span>
Knowledge distillation is a common method for compressing models in order to improve inference speed. Here are some reference papers, followed by a sketch of the classic soft-label objective:
- [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
- [Model Compression with Multi-Task Knowledge Distillation for Web-scale Question Answering System](https://arxiv.org/abs/1904.09636)
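As background, the first paper above distills a classifier by softening both the teacher's and the student's output distributions with a temperature. Below is a minimal PyTorch sketch of that objective; it is illustrative only (not NeuronBlocks' implementation), and the argument names and default hyperparameters are assumptions:

```python
import torch.nn.functional as F

def soft_label_kd_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soften both distributions with temperature T (Hinton et al., 2015).
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    # The KL term is scaled by T^2 so its gradients stay comparable in size
    # to the optional hard-label cross-entropy term.
    distill = F.kl_div(log_student, soft_teacher, reduction="batchmean") * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * distill + (1 - alpha) * hard
```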
#### <span id="task-7.1">7.1: Compression for Query Binary Classifier</span>
This task trains a lightweight query regression model (the student) to learn from a heavy teacher model, such as a BERT-based query classifier. Training minimizes the difference between the student model's output score and the teacher model's output score.
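A minimal sketch of what one such training step looks like, assuming a PyTorch student model that emits one score per query and a `teacher_score` tensor holding the soft labels from the **Score** column (names are illustrative; in NeuronBlocks the actual training loop is driven by the JSON config used in the Usage steps below):

```python
import torch.nn.functional as F

def distillation_step(student, optimizer, query_batch, teacher_score):
    optimizer.zero_grad()
    student_score = student(query_batch).squeeze(-1)  # one score per query
    loss = F.mse_loss(student_score, teacher_score)   # match the teacher's soft label
    loss.backward()
    optimizer.step()
    return loss.item()
```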
- ***Dataset***
*PROJECT_ROOT/dataset/knowledge_distillation/query_binary_classifier*:
* *train.tsv* and *valid.tsv*: two columns, namely **Query** and **Score**.
**Score** is the output score of a heavy teacher model (a fine-tuned BERT-base model); it is the soft label that the student model learns as knowledge.
* *test.tsv*: two columns, namely **Query** and **Label**.
**Label** is a binary value, where 0 means negative and 1 means positive.
You can also replace this dataset with your own for the compression training task (see the format-check sketch below).
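If you bring your own data, a quick sanity check like the following can catch format problems early. This is only a sketch: it assumes tab-separated files with no header row, matching the two-column layout described above, and the path is just an example.

```python
def check_distillation_tsv(path, numeric_column=1):
    """Check each line has exactly two tab-separated fields and that the
    second field is numeric (teacher Score for train/valid, 0/1 Label for test)."""
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.rstrip("\n").split("\t")
            assert len(fields) == 2, f"line {line_no}: expected 2 columns, got {len(fields)}"
            float(fields[numeric_column])  # raises ValueError if the field is not numeric

check_distillation_tsv("dataset/knowledge_distillation/query_binary_classifier/train.tsv")
```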
- ***Usage***

1. Train student model
```bash
cd PROJECT_ROOT
python train.py --conf_path=model_zoo/nlp_tasks/knowledge_distillation/conf_kdqbc_bilstmattn_cnn.json
```

2. Test student model
```bash
cd PROJECT_ROOT
python test.py --conf_path=models/kdqbc_bilstmattn_cnn/train/conf_kdqbc_bilstmattn_cnn.json --previous_model_path models/kdqbc_bilstmattn_cnn/train/model.nb --predict_output_path models/kdqbc_bilstmattn_cnn/test/test.tsv --test_data_path dataset/knowledge_distillation/query_binary_classifier/test.tsv
```

3. Calculate AUC metric (a rough Python equivalent is sketched after these steps)
```bash
cd PROJECT_ROOT
python tools/AUC.py --input_file models/kdqbc_bilstmattn_cnn/test/test.tsv --predict_index 2 --label_index 1
```
*Tips: you can try different models by running different JSON config files.*
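For reference, the AUC step above is roughly equivalent to the following scikit-learn sketch. It assumes the prediction file written by test.py keeps the gold label in column 1 and the student's score in column 2, as the `--label_index`/`--predict_index` flags indicate; the real `tools/AUC.py` may differ in details.

```python
import csv
from sklearn.metrics import roc_auc_score

def auc_from_tsv(path, label_index=1, predict_index=2):
    labels, scores = [], []
    with open(path, encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            labels.append(int(row[label_index]))
            scores.append(float(row[predict_index]))
    return roc_auc_score(labels, scores)

print(auc_from_tsv("models/kdqbc_bilstmattn_cnn/test/test.tsv"))
```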
- ***Result***

The AUC of the student model is very close to that of the teacher model, while the student's inference speed is 3.5x to 4x faster.

|Model|AUC|
|-----|---|
|Teacher|0.9112|
|Student-BiLSTMAttn+TextCNN (NeuronBlocks)|0.8941|

*Tips: the model file and train log file can be found in JSON config file's outputs/save_base_dir.*
#### <span id="task-7.2">7.2: Compression for Text Matching Model (ongoing)</span>
#### <span id="task-7.3">7.3: Compression for Slot Filling Model (ongoing)</span>
#### <span id="task-7.4">7.4: Compression for MRC Model (ongoing)</span>
## <span id="advanced-usage">Advanced Usage</span>
After building a model, the next goal is to train it to good performance. This depends both on a highly expressive model and on training tricks, and NeuronBlocks provides several such tricks.
