
Commit 0373178

committed: add inference python

1 parent 50e5368 · commit 0373178

7 files changed: +222 −11 lines


doc/inference.md

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# How to Use Paddle Inference

PaddleRec currently supports saving models with the save_inference_model interface during static-graph training, and deploying the saved models on the server side with the Paddle Inference library. This tutorial uses the wide_deep model as an example to walk through both features.

## Saving a Model with the save_inference_model Interface

Server-side deployment with Python first requires saving the model through the save_inference_model interface.

1. Add the use_inference parameter to the model's yaml configuration. use_inference controls whether the model is saved via the save_inference_model interface; it defaults to False. A model saved this way can be loaded for prediction with Paddle Inference, but can no longer be loaded directly by PaddleRec's native prediction method.

2. Decide which variables serve as the prediction model's inputs and outputs, and list their names as strings in save_inference_feed_varnames and save_inference_fetch_varnames.

Taking the wide_deep model as an example, its config.yaml contains the structure below. The training and test sets are the Criteo dataset used in the [Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/). The dataset has two parts: the training set covers a portion of Criteo's traffic over a period of time, and the test set covers the ad-click traffic of the day following the training data. Among the feed variable names, ```<label>``` indicates whether the ad was clicked (1 for clicked, 0 for not clicked), ```<integer feature>``` denotes the 13 numerical features (the continuous feature dense_input), and ```<categorical feature>``` denotes the 26 categorical features (discrete features C1-C26). The fetch variable outputs the auc; concretely, it is the cast operator from the statement in the net() function of static_model.py that casts auc to float32.

```yaml
runner:
  # common settings omitted here
  ...
  # use inference save model
  use_inference: True # save an inference model during static-graph training
  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # feed variable names of the inference model
  save_inference_fetch_varnames: ["cast_0.tmp_0"] # fetch variable names of the inference model
```

3. Start static-graph training:

```bash
# enter the model directory
# cd models/rank/wide_deep # can be run from any directory
# static-graph training
python -u ../../../tools/static_trainer.py -m config.yaml # run config_bigdata.yaml for the full dataset
```

## Deploying the Saved Model on the Server Side with the Inference Library

PaddleRec provides the tools/paddle_infer.py script so you can conveniently run efficient predictions with the Inference library.

1. Arguments for launching the paddle_infer.py script:

| Name | Type | Value | Required | Description |
| :-----------------: | :-------: | :--------------------------: | :-----: | :------------------------------------------------------------------: |
| --model_file | string | any path || path to the model file (used when loading a Combined model from disk) |
| --params_file | string | any path || path to the params file (used when loading a Combined model from disk) |
| --model_dir | string | any path || path to the model directory (used when loading a non-Combined model from disk) |
| --use_gpu | bool | True/False || whether to use the GPU |
| --data_dir | string | any path || directory of the test data |
| --reader_file | string | any path || path to the python file containing the Reader() used at test time |
| --batchsize | int | >= 1 || number of samples per batch |

2. Using the wide_deep model's demo data as an example, start prediction:

```bash
# enter the model directory
# cd models/rank/wide_deep # can be run from any directory
python -u ../../../tools/paddle_infer.py --model_file=output_model_wide_deep/2/rec_inference.pdmodel --params_file=output_model_wide_deep/2/rec_inference.pdiparams --use_gpu=False --data_dir=data/sample_data/train --reader_file=criteo_reader.py --batchsize=5
```
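Before training, it can help to sanity-check that every configured varname actually exists in the program. The sketch below is hypothetical and not part of PaddleRec: `check_varnames` and the `program_vars` dict stand in for `paddle.static.default_main_program().global_block().vars`, which is what the trainer consults at save time.

```python
def check_varnames(configured_names, program_vars):
    """Return the variables matching configured_names, or raise a
    ValueError listing the names the program actually contains."""
    missing = [name for name in configured_names if name not in program_vars]
    if missing:
        raise ValueError(
            "variables {} not in program; available: {}".format(
                missing, sorted(program_vars)))
    return [program_vars[name] for name in configured_names]

# toy stand-in for the program's variable table
program_vars = {"label": "var_label", "dense_input": "var_dense",
                "cast_0.tmp_0": "var_auc"}
print(check_varnames(["label", "dense_input"], program_vars))
```

A misspelled entry in save_inference_feed_varnames then fails fast with the list of valid names, instead of surfacing later as a confusing save error.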

doc/yaml.md

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@
 | print_interval | int | >= 1 || interval in batches for printing training metrics |
 | use_auc | bool | True/False || reset the auc metric at the start of each epoch |
 | use_visual | bool | True/False || enable training visualization; requires visualDL |
+| use_inference | bool | True/False || whether to save the model via the save_inference_model interface |
+| save_inference_feed_varnames | list[string] | names of Variables in the network || input variable names of the inference model |
+| save_inference_fetch_varnames | list[string] | names of Variables in the network || output variable names of the inference model |


## hyper_parameters variables

models/rank/wide_deep/config.yaml

Lines changed: 5 additions & 1 deletion
@@ -28,8 +28,12 @@ runner:
   infer_reader_path: "criteo_reader" # importlib format
   infer_batch_size: 5
   infer_load_path: "output_model_wide_deep"
-  infer_start_epoch: 0
+  infer_start_epoch: 2
   infer_end_epoch: 3
+  # use inference save model
+  use_inference: True
+  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"]
+  save_inference_fetch_varnames: ["cast_0.tmp_0"]

 # hyper parameters of user-defined network
 hyper_parameters:

models/rank/wide_deep/static_model.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ def net(self, input, is_infer=False):
             label=self.label_input,
             num_thresholds=2**12,
             slide_steps=20)
+        auc = paddle.cast(auc, "float32")
         self.inference_target_var = auc
         if is_infer:
             fetch_dict = {'auc': auc}

tools/paddle_infer.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import os
import paddle.nn as nn
import numpy as np
import time
import logging
import sys
from importlib import import_module
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model
from utils.save_load import save_model, load_model
from paddle.io import DistributedBatchSampler, DataLoader
import argparse
from paddle.inference import Config
from paddle.inference import create_predictor


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_file", type=str)
    parser.add_argument("--params_file", type=str)
    parser.add_argument("--model_dir", type=str)
    parser.add_argument("--use_gpu", type=bool)
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--reader_file", type=str)
    parser.add_argument("--batchsize", type=int)
    args = parser.parse_args()
    return args


def init_predictor(args):
    # a non-Combined model loads from a directory; a Combined model
    # loads from an explicit model_file/params_file pair
    if args.model_dir:
        config = Config(args.model_dir)
    else:
        config = Config(args.model_file, args.params_file)

    if args.use_gpu:
        config.enable_use_gpu(1000, 0)
    else:
        config.disable_gpu()
    predictor = create_predictor(config)
    return predictor


def create_data_loader(args):
    data_dir = args.data_dir
    # strip the .py extension so the reader file can be imported as a module
    reader_file = args.reader_file.split(".")[0]
    batchsize = args.batchsize
    place = args.place
    file_list = [os.path.join(data_dir, x) for x in os.listdir(data_dir)]
    sys.path.append(os.path.abspath("."))
    reader_class = import_module(reader_file)
    dataset = reader_class.RecDataset(file_list, config=None)
    loader = DataLoader(
        dataset, batch_size=batchsize, places=place, drop_last=True)
    return loader


def main(args):
    predictor = init_predictor(args)
    place = paddle.set_device('gpu' if args.use_gpu else 'cpu')
    args.place = place
    input_names = predictor.get_input_names()
    output_names = predictor.get_output_names()
    test_dataloader = create_data_loader(args)
    for batch_id, batch_data in enumerate(test_dataloader):
        name_data_pair = dict(zip(input_names, batch_data))
        for name in input_names:
            input_tensor = predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(name_data_pair[name].numpy())
        predictor.run()
        results = []
        for name in output_names:
            output_tensor = predictor.get_output_handle(name)
            output_data = output_tensor.copy_to_cpu()[0]
            results.append(output_data)
        print(results)


if __name__ == '__main__':
    args = parse_args()
    main(args)
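One detail worth noting in create_data_loader() above: `--reader_file` is converted to an importable module name with a simple `split(".")`. The sketch below (names are illustrative, not part of the script's API) shows exactly what that does, including its limitation with dotted filenames.

```python
def reader_module_name(reader_file):
    """Drop the extension from a reader filename so it can be passed
    to importlib.import_module, e.g. "criteo_reader.py" -> "criteo_reader"."""
    return reader_file.split(".")[0]

print(reader_module_name("criteo_reader.py"))  # criteo_reader
# Caveat: split(".") keeps only the part before the FIRST dot, so a
# dotted filename like "my.reader.py" would be truncated to "my".
print(reader_module_name("my.reader.py"))  # my
```

This is why the docs pass `--reader_file=criteo_reader.py` and expect a file of that name in the current working directory, which the script adds to sys.path before importing.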

tools/static_trainer.py

Lines changed: 51 additions & 10 deletions
@@ -24,7 +24,7 @@

 from utils.static_ps.reader_helper import get_reader
 from utils.utils_single import load_yaml, load_static_model_class, get_abs_model, create_data_loader, reset_auc
-from utils.save_load import save_static_model
+from utils.save_load import save_static_model, save_inference_model

 import time
 import argparse
@@ -36,6 +36,7 @@
 def parse_args():
     parser = argparse.ArgumentParser("PaddleRec train static script")
     parser.add_argument("-m", "--config_yaml", type=str)
+    parser.add_argument("-o", "--opt", nargs='*', type=str)
     args = parser.parse_args()
     args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
     args.config_yaml = get_abs_model(args.config_yaml)
@@ -49,6 +50,12 @@ def main(args):
     config = load_yaml(args.config_yaml)
     config["yaml_path"] = args.config_yaml
     config["config_abs_dir"] = args.abs_dir
+    # modify config from command line
+    if args.opt:
+        for parameter in args.opt:
+            parameter = parameter.strip()
+            key, value = parameter.split("=")
+            config[key] = value
     # load static model class
     static_model_class = load_static_model_class(config)

@@ -63,6 +70,7 @@
     use_gpu = config.get("runner.use_gpu", True)
     use_auc = config.get("runner.use_auc", False)
     use_visual = config.get("runner.use_visual", False)
+    use_inference = config.get("runner.use_inference", False)
     auc_num = config.get("runner.auc_num", 1)
     train_data_dir = config.get("runner.train_data_dir", None)
     epochs = config.get("runner.epochs", None)
@@ -74,9 +82,9 @@
     os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
     logger.info("**************common.configs**********")
     logger.info(
-        "use_gpu: {}, use_visual: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
-        format(use_gpu, use_visual, train_data_dir, epochs, print_interval,
-               model_save_path))
+        "use_gpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
+        format(use_gpu, use_visual, batch_size, train_data_dir, epochs,
+               print_interval, model_save_path))
     logger.info("**************common.configs**********")

     place = paddle.set_device('gpu' if use_gpu else 'cpu')
@@ -124,11 +132,44 @@
         else:
             logger.info("reader type wrong")

-        save_static_model(
-            paddle.static.default_main_program(),
-            model_save_path,
-            epoch_id,
-            prefix='rec_static')
+        if use_inference:
+            feed_var_names = config.get("runner.save_inference_feed_varnames",
+                                        [])
+            feedvars = []
+            fetch_var_names = config.get(
+                "runner.save_inference_fetch_varnames", [])
+            fetchvars = []
+            for var_name in feed_var_names:
+                if var_name not in paddle.static.default_main_program(
+                ).global_block().vars:
+                    raise ValueError(
+                        "Feed variable: {} not in default_main_program, global block has follow vars: {}".
+                        format(var_name,
+                               paddle.static.default_main_program()
+                               .global_block().vars.keys()))
+                else:
+                    feedvars.append(paddle.static.default_main_program()
+                                    .global_block().vars[var_name])
+            for var_name in fetch_var_names:
+                if var_name not in paddle.static.default_main_program(
+                ).global_block().vars:
+                    raise ValueError(
+                        "Fetch variable: {} not in default_main_program, global block has follow vars: {}".
+                        format(var_name,
+                               paddle.static.default_main_program()
+                               .global_block().vars.keys()))
+                else:
+                    fetchvars.append(paddle.static.default_main_program()
+                                     .global_block().vars[var_name])
+
+            save_inference_model(model_save_path, epoch_id, feedvars,
+                                 fetchvars, exe)
+        else:
+            save_static_model(
+                paddle.static.default_main_program(),
+                model_save_path,
+                epoch_id,
+                prefix='rec_static')


 def dataset_train(epoch_id, dataset, fetch_vars, exe, config):
@@ -179,7 +220,7 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
             logger.info(
                 "epoch: {}, batch_id: {}, ".format(epoch_id,
                                                    batch_id) + metric_str +
-                "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
+                "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} ins/s".
                 format(train_reader_cost / print_interval, (
                     train_reader_cost + train_run_cost) / print_interval,
                     total_samples / print_interval, total_samples / (
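The `-o key=value` override logic added to static_trainer.py can be exercised on its own; the `apply_overrides` helper below is a standalone re-statement of that loop, not a function in the repository. Note the values remain strings, so downstream code has to coerce them as needed.

```python
def apply_overrides(config, opts):
    """Apply "-o key=value" style overrides to a config dict, mirroring
    the loop in static_trainer.py. Each token is split on a single "="
    and the raw string value is stored; "a=b=c" would raise ValueError
    from tuple unpacking, matching the original's strictness."""
    for parameter in opts or []:
        parameter = parameter.strip()
        key, value = parameter.split("=")
        config[key] = value
    return config

config = {"runner.use_gpu": True}
apply_overrides(config, ["runner.use_inference=True", "runner.epochs=3"])
print(config["runner.use_inference"])  # True (as the string "True")
```

Because the stored value is the string "True" rather than the boolean, a truthiness check like `if config.get("runner.use_inference")` would also pass for the string "False"; that is a subtlety of this override style worth keeping in mind.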

tools/utils/save_load.py

Lines changed: 19 additions & 0 deletions
@@ -61,6 +61,25 @@ def save_static_model(program, model_path, epoch_id, prefix='rec_static'):
     logger.info("Already save model in {}".format(model_path))


+def save_inference_model(model_path,
+                         epoch_id,
+                         feed_vars,
+                         fetch_vars,
+                         exe,
+                         prefix='rec_inference'):
+    """
+    save inference model to target path
+    """
+    model_path = os.path.join(model_path, str(epoch_id))
+    _mkdir_if_not_exist(model_path)
+    model_prefix = os.path.join(model_path, prefix)
+    paddle.static.save_inference_model(
+        path_prefix=model_prefix,
+        feed_vars=feed_vars,
+        fetch_vars=fetch_vars,
+        executor=exe)
+
+
 def load_static_model(program, model_path, prefix='rec_static'):
     logger.info("start load model from {}".format(model_path))
     model_prefix = os.path.join(model_path, prefix)
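The path layout produced by save_inference_model() above explains the paths used in the docs: one subdirectory per epoch, files named by the prefix. A small sketch of that construction (the helper name is illustrative):

```python
import os

def inference_path_prefix(model_save_path, epoch_id, prefix="rec_inference"):
    """Reproduce the path prefix that save_inference_model builds:
    <model_save_path>/<epoch_id>/<prefix>."""
    return os.path.join(model_save_path, str(epoch_id), prefix)

print(inference_path_prefix("output_model_wide_deep", 2))
# output_model_wide_deep/2/rec_inference
```

Paddle then appends the file extensions to this prefix, which is why paddle_infer.py is pointed at output_model_wide_deep/2/rec_inference.pdmodel and .pdiparams for the model saved at epoch 2.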
