Commit 5515597: cherry-pick slim mkl-dnn doc (#22654)
Parent: f41449e

1 file changed: 65 additions, 31 deletions
@@ -1,20 +1,22 @@
# SLIM Quantization-aware training (QAT) on INT8 MKL-DNN

This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/advanced_usage/paddle_slim/paddle_slim.md) to convert a quantization-aware trained model into an INT8 MKL-DNN quantized model. In **Release 1.5**, we released QAT1.0 MKL-DNN, which enables the INT8 MKL-DNN kernels for QAT-trained models within a 0.05% accuracy difference on GoogleNet, MobileNet-V1, MobileNet-V2, ResNet-101, ResNet-50, VGG16 and VGG19. In **Release 1.6**, QAT2.0 MKL-DNN added performance optimizations for fake QAT models (ResNet50, ResNet101, MobileNet-V1, MobileNet-V2, VGG16 and VGG19) with a minor accuracy drop. Compared with Release 1.5, QAT2.0 MKL-DNN achieves a larger inference performance gain over fake QAT models, at the cost of a slightly larger accuracy difference. In **Release 1.7**, support for the [Ernie (NLP) QAT trained model](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie) was added to QAT2.0 MKL-DNN. We provide accuracy benchmarks for both QAT1.0 MKL-DNN and QAT2.0 MKL-DNN, and a performance benchmark for QAT2.0 MKL-DNN.

Notes:

* MKL-DNN and MKL are required. The performance gain can only be obtained with AVX512 series CPU servers.
* INT8 accuracy is best on CPU servers supporting the AVX512 VNNI extension (a quick way to check for these CPU features is shown below).
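
A minimal Python sketch, assuming a Linux host where `/proc/cpuinfo` is available, to check whether the CPU advertises the AVX512 and VNNI instruction sets mentioned above:

```python
# Minimal sketch: report whether this (Linux) CPU exposes AVX512F and AVX512 VNNI.
with open('/proc/cpuinfo') as f:
    flags = f.read()

print('AVX512F     :', 'avx512f' in flags)      # required for the INT8 performance gain
print('AVX512 VNNI :', 'avx512_vnni' in flags)  # gives the best INT8 accuracy
```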

## 0. Prerequisite
You need to install at least the PaddlePaddle-1.7 python package: `pip install paddlepaddle==1.7`.
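
A quick sanity check of the installed version (a minimal sketch, assuming the `pip install` above succeeded):

```python
# Print the installed PaddlePaddle version; expect 1.7.0 or newer.
import paddle

print(paddle.__version__)
```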

## 1. How to generate INT8 MKL-DNN QAT model
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users first apply a PaddleSlim quantization strategy and obtain a saved fake QAT model through the [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), and then use the `QatInt8MkldnnPass` (from QAT1.0 MKL-DNN) to get a graph which can be run with MKL-DNN INT8 kernels. In Paddle Release 1.6, this pass supports `conv2d` and `depthwise_conv2d` ops with channel-wise quantization for weights. Apart from it, another pass called `Qat2Int8MkldnnPass` (from QAT2.0 MKL-DNN) is available. In Release 1.6, this pass additionally supports the `pool2d` op and allows users to transform their QAT model into a highly performance-optimized INT8 model that is run using INT8 MKL-DNN kernels. In Release 1.7, support for the `fc`, `reshape2` and `transpose2` ops was added to this pass.

```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

@@ -23,14 +25,14 @@ You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quanti
place = fluid.CPUPlace()
# Convert the IrGraph to MKL-DNN supported INT8 IrGraph by using
# QAT1.0 MKL-DNN
# QatInt8MkldnnPass
mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(), place)
# Apply QatInt8MkldnnPass to IrGraph
mkldnn_pass.apply(graph)
# QAT2.0 MKL-DNN
# Qat2Int8MkldnnPass, it requires a list of operators to be quantized
mkldnn_pass = Qat2Int8MkldnnPass({'conv2d', 'pool2d'}, fluid.global_scope(), place, fluid.core, False)
# Apply Qat2Int8MkldnnPass to IrGraph
mkldnn_pass.apply(graph)

```
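
The snippet above assumes that `graph` already holds the fake QAT model as an `IrGraph`. A minimal sketch of how such a graph could be obtained from a saved fake QAT inference model and saved back after the transformation is shown below; the model and output paths are placeholders, and the exact loading code may differ for a particular model:

```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load a saved fake QAT inference model (placeholder path).
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    '/PATH/TO/DOWNLOAD/MODEL/resnet50', exe)

# Wrap the program in an IrGraph so that the MKL-DNN pass can be applied.
graph = IrGraph(core.Graph(program.desc), for_test=True)

# Apply the QAT2.0 MKL-DNN pass, as shown above.
mkldnn_pass = Qat2Int8MkldnnPass({'conv2d', 'pool2d'}, fluid.global_scope(), place,
                                 fluid.core, False)
mkldnn_pass.apply(graph)

# Convert the transformed graph back to a Program and save it for inference.
optimized_program = graph.to_program()
fluid.io.save_inference_model('/PATH/TO/SAVE/INT8/MODEL', feed_names, fetch_targets,
                              exe, optimized_program)
```

For a ready-made end-to-end flow, see the `save_qat_model.py` script used in section 3 below.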
@@ -61,10 +63,11 @@ You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quanti
| VGG16 | 71.74% | 71.75% | +0.01% | 89.96% | 89.73% | -0.23% |
| VGG19 | 72.30% | 72.09% | -0.21% | 90.19% | 90.13% | -0.06% |


>**III. QAT2.0 MKL-DNN C-API Performance on Intel(R) Xeon(R) Gold 6271**

| Model | FP32 Optimized Throughput (images/s) | INT8 QAT Throughput (images/s) | Ratio (INT8/FP32) |
|:------------:|:------------------------------------:|:------------------------------:|:----------------:|
| MobileNet-V1 | 73.98 | 227.73 | 3.08 |
| MobileNet-V2 | 86.59 | 206.74 | 2.39 |
| ResNet101 | 7.15 | 26.69 | 3.73 |
@@ -76,64 +79,95 @@ Notes:

* FP32 Optimized Throughput (images/s) is from [int8_mkldnn_quantization.md](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/tests/api/int8_mkldnn_quantization.md).

>**IV. Ernie QAT2.0 MKL-DNN Accuracy on Intel(R) Xeon(R) Gold 6271**

| Model | FP32 Accuracy | QAT INT8 Accuracy | Accuracy Diff |
|:------------:|:----------------------:|:----------------------:|:---------:|
| Ernie | 0.80 | 0.82 | +0.02 |


>**V. Ernie QAT2.0 MKL-DNN Performance on Intel(R) Xeon(R) Gold 6271**

| Threads | FP32 Latency (ms) | QAT INT8 Latency (ms) | Ratio (FP32/INT8) |
|:------------:|:----------------------:|:-------------------:|:---------:|
| 1 thread | 252.131 | 93.8023 | 2.687x |
| 20 threads | 29.1853 | 17.3765 | 1.680x |

## 3. How to reproduce the results
Three steps are needed to reproduce the above-mentioned accuracy and performance results. Below we explain the steps, taking ResNet50 as an example of the image classification models. To reproduce the NLP (Ernie) results, please follow [this guide](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn/README.md).

### Prepare dataset

#### Image classification

In order to download the dataset for benchmarking the image classification models, execute:
```bash
cd /PATH/TO/PADDLE
python paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py
```

The converted data binary file is saved by default in `$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin`.

### Prepare model

#### Image classification

You can run the following commands to download the ResNet50 model. The exemplary code snippet provided below downloads a ResNet50 QAT model. The reason for having two different versions of the same model is that two different QAT training strategies were used: one for a non-optimized and one for an optimized graph transform, which correspond to QAT1.0 and QAT2.0 respectively.

```bash
mkdir -p /PATH/TO/DOWNLOAD/MODEL/
cd /PATH/TO/DOWNLOAD/MODEL/
# uncomment for QAT1.0 MKL-DNN
# export MODEL_NAME=ResNet50
# export MODEL_FILE_NAME=QAT_models/${MODEL_NAME}_qat_model.tar.gz
# uncomment for QAT2.0 MKL-DNN
# export MODEL_NAME=resnet50
# export MODEL_FILE_NAME=QAT2_models/${MODEL_NAME}_quant.tar.gz
wget http://paddle-inference-dist.bj.bcebos.com/int8/${MODEL_FILE_NAME}
mkdir ${MODEL_NAME} && tar -xvf $(basename ${MODEL_FILE_NAME}) -C ${MODEL_NAME}
```

Extract the downloaded model to the folder. To verify all 7 models, set `MODEL_NAME` to one of the following values on the command line:
```text
QAT1.0 models
MODEL_NAME=ResNet50, ResNet101, GoogleNet, MobileNetV1, MobileNetV2, VGG16, VGG19
QAT2.0 models
MODEL_NAME=resnet50, resnet101, mobilenetv1, mobilenetv2, vgg16, vgg19
```
### Commands to reproduce benchmark

#### Image classification
You can use the `qat_int8_image_classification_comparison.py` script to reproduce the accuracy result on ResNet50. The difference between the commands used for QAT1.0 MKL-DNN and QAT2.0 MKL-DNN is that for QAT2.0 MKL-DNN two additional options are required: the `--qat2` option to enable QAT2.0 MKL-DNN, and the `--quantized_ops` option with a comma-separated list of operators to be quantized. To perform the QAT2.0 MKL-DNN performance test, the environment variable `OMP_NUM_THREADS=1` and the `--batch_size=1` option should be set.
>*QAT1.0*

- Accuracy benchmark command on QAT1.0 models

```bash
cd /PATH/TO/PADDLE
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOAD/MODEL/${MODEL_NAME}/model --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.001
```
>*QAT2.0*

- Accuracy benchmark command on QAT2.0 models

```bash
cd /PATH/TO/PADDLE
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOAD/MODEL/${MODEL_NAME}_quant --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --qat2 --quantized_ops="conv2d,pool2d"
```

- Performance benchmark command on QAT2.0 models

In order to run the performance benchmark, follow the steps below.

134-
# 2. Run the QAT2.0 C-API for performance benchmark
135-
cd /PATH/TO/PADDLE/build
136-
OMP_NUM_THREADS=1 paddle/fluid/inference/tests/api/test_analyzer_qat_image_classification ARGS --enable_fp32=false --with_accuracy_layer=false --int8_model=/PATH/TO/${QAT2_MODEL_NAME}_qat_int8 --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=1 --paddle_num_threads=1
137-
```
159+
1. Save QAT2.0 INT8 model. You can use the script `save_qat_model.py` for this purpose. It also requires the option `--quantized_ops` to indicate which operators are to be quantized.
160+
161+
```bash
162+
cd /PATH/TO/PADDLE/build
163+
python ../python/paddle/fluid/contrib/slim/tests/save_qat_model.py --qat_model_path=/PATH/TO/DOWNLOAD/MODEL/${QAT2_MODEL_NAME} --int8_model_save_path=/PATH/TO/${QAT2_MODEL_NAME}_qat_int8 --quantized_ops="conv2d,pool2d"
164+
```
165+
166+
2. Run the QAT2.0 C-API test for performance benchmark.
167+
168+
```bash
169+
cd /PATH/TO/PADDLE/build
170+
OMP_NUM_THREADS=1 paddle/fluid/inference/tests/api/test_analyzer_qat_image_classification ARGS --enable_fp32=false --with_accuracy_layer=false --int8_model=/PATH/TO/${QAT2_MODEL_NAME}_qat_int8 --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=1 --paddle_num_threads=1
171+
```
138172

139173
> Notes: Due to a large amount of images contained in `int8_full_val.bin` dataset (50 000), the accuracy benchmark which includes comparison of unoptimized and optimized QAT model may last long (even several hours). To accelerate the process, it is recommended to set `OMP_NUM_THREADS` to the max number of physical cores available on the server.

0 commit comments

Comments
 (0)