11 changes: 7 additions & 4 deletions deploy/slim/quantization/README.md
@@ -1,5 +1,5 @@

# PP-OCR Model Quantization

A complex model tends to perform better, but it also carries a degree of redundancy. Model quantization removes this redundancy by reducing full-precision parameters to fixed-point numbers, lowering the model's computational complexity and improving inference performance.
Quantization converts FP32 model parameters to Int8 precision with essentially no loss of accuracy, shrinking the parameter size and accelerating computation; quantized models have a clear speed advantage when deployed on mobile devices and similar targets.
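
To make the idea concrete, here is a minimal, self-contained sketch of symmetric Int8 quantization. It illustrates the arithmetic only; it is not PaddleOCR or PaddleSlim code:

```python
import numpy as np

def quantize_int8(weights: np.ndarray):
    """Symmetric per-tensor Int8 quantization: map FP32 values into [-127, 127]."""
    scale = np.abs(weights).max() / 127.0  # one FP32 scale for the whole tensor
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q: np.ndarray, scale: float) -> np.ndarray:
    """Recover approximate FP32 values; the rounding error is the quantization loss."""
    return q.astype(np.float32) * scale

w = np.random.randn(4, 4).astype(np.float32)
q, scale = quantize_int8(w)
print(np.abs(w - dequantize_int8(q, scale)).max())  # small reconstruction error
```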

@@ -8,11 +8,12 @@

Before starting this tutorial, it is recommended to first read [the PaddleOCR model training guide](../../../doc/doc_ch/training.md) and the [PaddleSlim documentation](https://paddleslim.readthedocs.io/zh_CN/latest/index.html).


## Quick Start

Quantization is mostly suited to deploying lightweight models on mobile devices. After training a model, if you wish to further compress its size and accelerate prediction, you can use quantization to compress the model.

Model quantization involves five main steps:

1. Install PaddleSlim
2. Prepare the trained model
3. Quantization-aware training
@@ -22,25 +23,27 @@
### 1. Install PaddleSlim

```bash
- pip3 install paddleslim==2.3.2
+ pip3 install paddleslim
```

### 2. Prepare the Trained Model

PaddleOCR provides a series of trained [models](../../../doc/doc_ch/models_list.md). If the model to be quantized is not in the list, obtain a trained model by following the [regular training](../../../doc/doc_ch/quickstart.md) procedure.
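
As a sketch, regular training is launched with PaddleOCR's `tools/train.py`; the output directory below is an illustrative assumption, not a command from this PR:

```bash
# hypothetical example: train a detection model before quantizing it
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.save_model_dir=./output/det_model/
```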

### 3. Quantization-Aware Training

Quantization training includes offline (post-training) quantization and online quantization-aware training. Online quantization training works better: it requires loading a pre-trained model, and the model can be quantized once the quantization strategy is defined.

The quantization training code is located in `slim/quantization/quant.py`. Taking the PP-OCRv3 detection model as an example, the training command is as follows:

```
# Download the pre-trained detection model:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf ch_PP-OCRv3_det_distill_train.tar

python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```

To quantize a recognition model instead, simply change the configuration file and the loaded model parameters.
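
For instance, a recognition run might look like the following sketch; the config path and pretrained-model URL are illustrative assumptions based on PaddleOCR's published PP-OCRv3 recognition model, not commands from this PR:

```
# hypothetical example: download a recognition pre-trained model
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
tar xf ch_PP-OCRv3_rec_train.tar

python deploy/slim/quantization/quant.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model='./ch_PP-OCRv3_rec_train/best_accuracy' Global.save_model_dir=./output/quant_rec_model/
```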

### 4. Export Model
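
The body of this section is collapsed in the diff view. As a minimal sketch, exporting uses `deploy/slim/quantization/export_model.py` with the quantized checkpoint from step 3; the exact flags below are assumptions, not text from this PR:

```
# hypothetical example: export the quantized model for inference
python deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.checkpoints=./output/quant_model_distill/best_accuracy Global.save_inference_dir=./output/quant_inference_model
```
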
12 changes: 7 additions & 5 deletions deploy/slim/quantization/README_en.md
@@ -1,4 +1,3 @@

# PP-OCR Models Quantization

Generally, a more complex model achieves better performance on the task, but it also introduces some redundancy into the model.
@@ -8,10 +8,12 @@ so as to reduce model calculation complexity and improve model inference performance.
This example uses the [quantization APIs](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/dygraph/quanter/qat.rst) provided by PaddleSlim to compress the OCR model.

It is recommended that you read the following pages before working through this example:

- [The training strategy of OCR model](../../../doc/doc_en/quickstart_en.md)
- [PaddleSlim Document](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/dygraph/quanter/qat.rst)

## Quick Start

Quantization is mostly suited to deploying lightweight models on mobile devices.
After training, if you want to further compress the model size and accelerate prediction, you can quantize the model according to the following steps.

@@ -21,25 +22,25 @@ After training, if you want to further compress the model size and accelerate the prediction
4. Export inference model
5. Deploy quantization inference model


### 1. Install PaddleSlim

```bash
- pip3 install paddleslim==2.3.2
+ pip3 install paddleslim
```


### 2. Download Pre-trained Model

PaddleOCR provides a series of pre-trained [models](../../../doc/doc_en/models_list_en.md).
If the model to be quantized is not in the list, you need to follow the [Regular Training](../../../doc/doc_en/quickstart_en.md) method to get a trained model.


### 3. Quant-Aware Training

Quantization training includes offline (post-training) quantization and online quantization-aware training.
Online quantization training is more effective: it requires loading the pre-trained model,
and the model can be quantized once the quantization strategy is defined.
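
As a sketch of what "defining the quantization strategy" looks like with PaddleSlim's dygraph `QAT` API — the option values and the stand-in model below are illustrative assumptions, not settings taken from this PR:

```python
# A minimal sketch, assuming PaddleSlim's dygraph QAT API.
import paddle
from paddleslim.dygraph.quant import QAT

quant_config = {
    "weight_quantize_type": "channel_wise_abs_max",        # per-channel weight scales
    "activation_quantize_type": "moving_average_abs_max",  # running-average activation scales
    "weight_bits": 8,
    "activation_bits": 8,
    "dtype": "int8",
    "window_size": 10000,
    "moving_rate": 0.9,
    "quantizable_layer_type": ["Conv2D", "Linear"],
}

model = paddle.nn.Sequential(  # stand-in for the pre-trained OCR model
    paddle.nn.Conv2D(3, 8, 3),
    paddle.nn.ReLU(),
    paddle.nn.Flatten(),
    paddle.nn.Linear(8 * 30 * 30, 10),
)

quanter = QAT(config=quant_config)
quanter.quantize(model)  # inserts fake-quant ops; then train as usual
# after training: quanter.save_quantized_model(model, "./quant_model", input_spec=[...])
```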

The code for quantization training is located in `slim/quantization/quant.py`. For example, the training command for quantizing the PP-OCRv3 detection model is as follows:

```
# download provided model
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
@@ -59,6 +60,7 @@ python deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP
```

### 5. Deploy

The parameters of the quantized model exported by the steps above are still stored as FP32, but their values fall within the int8 numerical range.
The exported model can be converted through the `opt` tool of PaddleLite.
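
A minimal sketch of that conversion, assuming the inference model was exported to `./output/quant_inference_model` (the paths and target below are illustrative assumptions):

```bash
# hypothetical example: convert the quantized inference model for PaddleLite
pip3 install paddlelite
paddle_lite_opt --model_file=./output/quant_inference_model/inference.pdmodel \
                --param_file=./output/quant_inference_model/inference.pdiparams \
                --optimize_out=./output/quant_inference_model_opt \
                --valid_targets=arm
```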

137 changes: 104 additions & 33 deletions deploy/slim/quantization/export_model.py
@@ -67,8 +67,12 @@ def main():
config = load_config(FLAGS.config)
config = merge_config(config, FLAGS.opt)
logger = get_logger()
# build post process

# build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, "Eval", device, logger)

# build post process
post_process_class = build_post_process(config["PostProcess"], config["Global"])

# build model
@@ -81,19 +85,14 @@
if (
config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
): # for multi head
out_channels_list = {}
if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
char_num = char_num - 2
# update SARLoss params
assert (
list(config["Loss"]["loss_config_list"][-1].keys())[0]
== "DistillationSARLoss"
)
config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
"ignore_index"
] = (char_num + 1)
out_channels_list = {}
if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
char_num = char_num - 3
out_channels_list["CTCLabelDecode"] = char_num
out_channels_list["SARLabelDecode"] = char_num + 2
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Models"][key]["Head"][
"out_channels_list"
] = out_channels_list
@@ -102,48 +101,109 @@ def main():
"out_channels"
] = char_num
elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
out_channels_list = {}
if config["PostProcess"]["name"] == "SARLabelDecode":
char_num = char_num - 2
# update SARLoss params
assert list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss"
if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
config["Loss"]["loss_config_list"][1]["SARLoss"] = {
"ignore_index": char_num + 1
}
else:
config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
char_num + 1
)
out_channels_list = {}
if config["PostProcess"]["name"] == "NRTRLabelDecode":
char_num = char_num - 3
out_channels_list["CTCLabelDecode"] = char_num
out_channels_list["SARLabelDecode"] = char_num + 2
out_channels_list["NRTRLabelDecode"] = char_num + 3
config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num

if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
config["Loss"]["ignore_index"] = char_num - 1

model = build_model(config["Architecture"])
extra_input_models = [
"SRN",
"NRTR",
"SAR",
"SEED",
"SVTR",
"SVTR_LCNet",
"VisionLAN",
"RobustScanner",
"SVTR_HGNet",
]
extra_input = False
if config["Architecture"]["algorithm"] == "Distillation":
for key in config["Architecture"]["Models"]:
extra_input = (
extra_input
or config["Architecture"]["Models"][key]["algorithm"]
in extra_input_models
)
else:
extra_input = config["Architecture"]["algorithm"] in extra_input_models
if "model_type" in config["Architecture"].keys():
if config["Architecture"]["algorithm"] == "CAN":
model_type = "can"
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
model_type = "latexocr"
config["Metric"]["cal_bleu_score"] = True
elif config["Architecture"]["algorithm"] == "UniMERNet":
model_type = "unimernet"
config["Metric"]["cal_bleu_score"] = True
elif config["Architecture"]["algorithm"] in [
"PP-FormulaNet-S",
"PP-FormulaNet-L",
]:
model_type = "pp_formulanet"
config["Metric"]["cal_bleu_score"] = True
else:
model_type = config["Architecture"]["model_type"]
else:
model_type = None

# get QAT model
quanter = QAT(config=quant_config)
quanter.quantize(model)

load_model(config, model)

# build metric
eval_class = build_metric(config["Metric"])
# amp
use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", "O2")
amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
"FLAGS_cudnn_batchnorm_spatial_persistent": 1,
}
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False
)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
)
if amp_level == "O2":
model = paddle.amp.decorate(
models=model, level=amp_level, master_weight=True
)
else:
scaler = None

# build dataloader
set_signal_handlers()
valid_dataloader = build_dataloader(config, "Eval", device, logger)
best_model_dict = load_model(
config, model, model_type=config["Architecture"]["model_type"]
)
if len(best_model_dict):
logger.info("metric in ckpt ***************")
for k, v in best_model_dict.items():
logger.info("{}:{}".format(k, v))

use_srn = config["Architecture"]["algorithm"] == "SRN"
model_type = config["Architecture"].get("model_type", None)
# start eval
metric = program.eval(
model, valid_dataloader, post_process_class, eval_class, model_type, use_srn
model,
valid_dataloader,
post_process_class,
eval_class,
model_type,
extra_input,
scaler,
amp_level,
amp_custom_black_list,
)
model.eval()

@@ -176,12 +236,23 @@ def main():
archs[idx],
sub_model_save_path,
logger,
None,
config,
input_shape,
quanter,
)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger, input_shape, quanter)
export_single_model(
model,
arch_config,
save_path,
logger,
None,
config,
input_shape,
quanter,
)


if __name__ == "__main__":