Commit a16b934

Merge pull request #410 from yinhaofeng/inference_python
add inference python
2 parents 779f512 + 2b1ea65 commit a16b934

File tree: 11 files changed, +531 −3 lines
doc/inference.md

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@

# How to use Paddle Inference

PaddleRec currently supports saving a model with the save_inference_model interface during static-graph training, and deploying the saved model for server-side prediction with the Paddle Inference library. This tutorial uses the wide_deep model as an example to show how to use both features.

## Saving a model with the save_inference_model interface

Server-side deployment with Python requires the model to be saved with the save_inference_model interface first.

1. Add the use_inference parameter to the model's yaml configuration and set it to True. use_inference controls whether the model is saved with the save_inference_model interface; it defaults to False. A model saved this way can be used for prediction with Paddle Inference, but can no longer be loaded directly by PaddleRec's native prediction method.

2. Decide which model variables you need as prediction inputs and outputs, and put their names, as strings, into the save_inference_feed_varnames and save_inference_fetch_varnames lists.

Taking the wide_deep model as an example, the structure below can be found in its config.yaml. The training and test data are the Criteo dataset used in the [Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/). The dataset has two parts: the training set contains a portion of Criteo's traffic over a period of time, and the test set contains the ad-click traffic of the day following the training data. Among the feed names, ```<label>``` indicates whether the ad was clicked (1 for clicked, 0 for not), ```<integer feature>``` denotes the numeric features (the 13 continuous dense_input features), and ```<categorical feature>``` denotes the categorical features (the 26 sparse features C1–C26). The fetch name is the auc output; concretely, it is the cast operator of the statement in def net() of static_model.py that converts auc to float32.
```yaml
runner:
  # common settings omitted
  ...
  # use inference save model
  use_inference: True # save an inference model during static-graph training
  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # feed variable names of the inference model
  save_inference_fetch_varnames: ["cast_0.tmp_0"] # fetch variable names of the inference model
```
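The 28 feed names above follow a fixed pattern — the click label, the 26 categorical features C1..C26, then the dense input — so the list can be generated rather than typed by hand. A minimal sketch:

```python
# Reproduce the save_inference_feed_varnames list for wide_deep:
# the click label, 26 sparse categorical features C1..C26, and the
# single dense_input variable carrying the 13 continuous features.
feed_varnames = ["label"] + ["C%d" % i for i in range(1, 27)] + ["dense_input"]
fetch_varnames = ["cast_0.tmp_0"]  # the float32 cast of auc in static_model.py

print(len(feed_varnames))  # 28 names in total
```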
3. Launch static-graph training

```bash
# enter the model directory
# cd models/rank/wide_deep # can be run from any directory
# static-graph training
python -u ../../../tools/static_trainer.py -m config.yaml # run config_bigdata.yaml for the full dataset
```
## Deploying the saved model on a server with the Paddle Inference library

PaddleRec provides the tools/paddle_infer.py script so that you can conveniently and efficiently run prediction on the model with the Paddle Inference library.

Packages to install:

```bash
pip install pynvml
pip install psutil
pip install GPUtil
```
1. Arguments of the paddle_infer.py script:

| Name | Type | Value | Required | Description |
| :-----------------: | :-------: | :--------------------------: | :-----: | :------------------------------------------------------------------: |
| --model_file | string | any path || path of the model file (used when loading a Combined model from disk) |
| --params_file | string | any path || path of the parameters file (used when loading a Combined model from disk) |
| --model_dir | string | any path || path of the model directory (used when loading a non-Combined model from disk) |
| --use_gpu | bool | True/False || whether to use the GPU |
| --data_dir | string | any path || directory of the test data |
| --reader_file | string | any path || path of the python file containing the Reader() used at test time |
| --batchsize | int | >= 1 || number of samples per batch at prediction time |
| --model_name | str | any name || name of the output model |
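The flags in the table can be wired up with argparse. The sketch below is a simplified stand-in for the real tools/paddle_infer.py command line — the defaults and the bool parsing here are assumptions for illustration, not the script's actual code:

```python
import argparse

def parse_args(argv=None):
    # Simplified mirror of the paddle_infer.py flags (hypothetical defaults).
    parser = argparse.ArgumentParser(description="Paddle Inference demo runner")
    parser.add_argument("--model_file", type=str, default=None)   # Combined model: model file
    parser.add_argument("--params_file", type=str, default=None)  # Combined model: params file
    parser.add_argument("--model_dir", type=str, default=None)    # non-Combined model directory
    parser.add_argument("--use_gpu", type=lambda s: s.lower() == "true", default=False)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--reader_file", type=str, default=None)
    parser.add_argument("--batchsize", type=int, default=1)
    parser.add_argument("--model_name", type=str, default="rec_model")
    return parser.parse_args(argv)

args = parse_args(["--model_file", "rec_inference.pdmodel",
                   "--params_file", "rec_inference.pdiparams",
                   "--use_gpu", "False", "--batchsize", "5"])
```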
2. Using the wide_deep model's demo data as an example, launch prediction:

```bash
# enter the model directory
# cd models/rank/wide_deep # can be run from any directory
python -u ../../../tools/paddle_infer.py --model_file=output_model_wide_deep/2/rec_inference.pdmodel --params_file=output_model_wide_deep/2/rec_inference.pdiparams --use_gpu=False --data_dir=data/sample_data/train --reader_file=criteo_reader.py --batchsize=5
```
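Whichever way the predictor is launched, each feed name in the config has to be paired with a batch of data. The sketch below shows what one wide_deep input batch could look like; the per-feature shapes (one label, 26 sparse ids, 13 dense floats per sample) are inferred from the dataset description above, not taken from paddle_infer.py:

```python
def build_feed_batch(samples):
    # samples: list of (label, sparse_ids, dense_values) tuples, where
    # sparse_ids holds the 26 C1..C26 ids and dense_values the 13 floats.
    feed = {"label": [], "dense_input": []}
    for i in range(1, 27):
        feed["C%d" % i] = []
    for label, sparse_ids, dense_values in samples:
        assert len(sparse_ids) == 26 and len(dense_values) == 13
        feed["label"].append([label])
        for i, sid in enumerate(sparse_ids, start=1):
            feed["C%d" % i].append([sid])
        feed["dense_input"].append(dense_values)
    return feed

batch = build_feed_batch([(1, list(range(26)), [0.1] * 13),
                          (0, list(range(26)), [0.2] * 13)])
```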

doc/yaml.md

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@
 | print_interval | int | >= 1 || interval, in batches, between printouts of training metrics |
 | use_auc | bool | True/False || reset the value of the auc metric at the start of each epoch |
 | use_visual | bool | True/False || enable visualization of model training; visualDL must be installed when enabled |
+| use_inference | bool | True/False || whether to save the model with the save_inference_model interface |
+| save_inference_feed_varnames | list[string] | names of designated Variables in the network || input variable names of the inference model |
+| save_inference_fetch_varnames | list[string] | names of designated Variables in the network || output variable names of the inference model |

 ## hyper_parameters settings
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
feed_var {
  name: "movieid"
  alias_name: "movieid"
  is_lod_tensor: true
  feed_type: 0
  shape: -1
}
feed_var {
  name: "title"
  alias_name: "title"
  is_lod_tensor: true
  feed_type: 0
  shape: -1
}
feed_var {
  name: "genres"
  alias_name: "genres"
  is_lod_tensor: true
  feed_type: 0
  shape: -1
}
fetch_var {
  name: "save_infer_model/scale_0.tmp_0"
  alias_name: "save_infer_model/scale_0.tmp_0"
  is_lod_tensor: false
  fetch_type: 1
  shape: 32
}
"""

import codecs

import numpy as np
from paddle_serving_app.local_predict import LocalPredictor


class Movie(object):
    def __init__(self):
        self.movie_id, self.title, self.genres = "", "", ""


def hash2(a):
    # Bucket any raw feature value into the 600000-slot sparse vocabulary.
    return hash(a) % 600000


ctr_client = LocalPredictor()
ctr_client.load_model_config("serving_server")
with codecs.open("movies.dat", "r", encoding='utf-8', errors='ignore') as f:
    lines = f.readlines()

ff = open("movie_vectors.txt", 'w')

for line in lines:
    if len(line.strip()) == 0:
        continue
    # Each line of movies.dat looks like: movie_id::title::genre1|genre2|...
    tmp = line.strip().split("::")
    movie_id = tmp[0]
    title = tmp[1]
    genre_group = tmp[2]

    genre = genre_group.strip().split("|")
    movie = Movie()
    item_infos = []
    if isinstance(genre, list):
        movie.genres = genre
    else:
        movie.genres = [genre]
    movie.movie_id, movie.title = movie_id, title
    item_infos.append(movie)

    dic = {"movieid": [], "title": [], "genres": []}
    for i, item_info in enumerate(item_infos):
        dic["movieid"].append(hash2(item_info.movie_id))
        dic["title"].append(hash2(item_info.title))
        dic["genres"].extend([hash2(x) for x in item_info.genres])

    # Pad the title to exactly 4 ids and the genres to exactly 3 ids;
    # truncate unconditionally so the fixed-width reshapes below cannot fail.
    if len(dic["title"]) <= 4:
        for i in range(4 - len(dic["title"])):
            dic["title"].append("0")
    dic["title"] = dic["title"][:4]
    if len(dic["genres"]) <= 3:
        for i in range(3 - len(dic["genres"])):
            dic["genres"].append("0")
    dic["genres"] = dic["genres"][:3]

    dic["movieid"] = np.array(dic["movieid"]).astype(np.int64).reshape(-1, 1)
    dic["title"] = np.array(dic["title"]).astype(np.int64).reshape(-1, 4)
    dic["genres"] = np.array(dic["genres"]).astype(np.int64).reshape(-1, 3)

    fetch_map = ctr_client.predict(
        feed=dic, fetch=["save_infer_model/scale_0.tmp_0"], batch=True)
    ff.write("{}:{}\n".format(movie_id,
                              str(fetch_map["save_infer_model/scale_0.tmp_0"]
                                  .tolist()[0])))
ff.close()
```
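The script above hashes every raw feature into a 600000-slot vocabulary and pads/truncates title to 4 ids and genres to 3. That preprocessing can be isolated and checked on its own; hash2 is copied from the script, while pad_or_truncate is a simplified stand-in for the padding loops (the script pads with "0" strings and casts later, which yields the same int64 values). Note that Python's built-in hash of a string varies between interpreter runs unless PYTHONHASHSEED is fixed, so the ids are only stable within one process:

```python
def hash2(a):
    # Same bucketing as the script: 600000 matches sparse_feature_number.
    return hash(a) % 600000

def pad_or_truncate(ids, size, pad_id=0):
    # Pad with pad_id up to `size` entries, then cut off any excess.
    return (ids + [pad_id] * size)[:size]

title_ids = pad_or_truncate([hash2(w) for w in ["Toy", "Story"]], 4)
genre_ids = pad_or_truncate([hash2(g) for g in ["Animation", "Children's", "Comedy", "Drama"]], 3)
```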
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@

```yaml
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "../data/train"
  train_reader_path: "reader" # importlib format
  train_batch_size: 1
  model_save_path: "movie_model"

  use_gpu: True
  epochs: 5
  print_interval: 20

  test_data_dir: "../data/test"
  infer_reader_path: "reader" # importlib format
  infer_batch_size: 1
  infer_load_path: "movie_model"
  infer_start_epoch: 4
  infer_end_epoch: 5

  runner_result_dump_path: "recall_infer_result"

  # use inference save model
  use_inference: True
  save_inference_feed_varnames: ["movieid", "title", "genres"]
  save_inference_fetch_varnames: ["linear_15.tmp_1"]

# hyper parameters of user-defined network
hyper_parameters:
  # optimizer config
  optimizer:
    class: Adam
    learning_rate: 0.001
  # user-defined <key, value> pairs
  sparse_feature_number: 600000
  sparse_feature_dim: 9
  dense_input_dim: 13
  fc_sizes: [512, 256, 128, 32]
```
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@

```yaml
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

runner:
  train_data_dir: "../data/train"
  train_reader_path: "reader" # importlib format
  train_batch_size: 1
  model_save_path: "user_model"

  use_gpu: True
  epochs: 5
  print_interval: 20

  test_data_dir: "../data/test"
  infer_reader_path: "reader" # importlib format
  infer_batch_size: 1
  infer_load_path: "user_model"
  infer_start_epoch: 4
  infer_end_epoch: 5

  runner_result_dump_path: "recall_infer_result"

  # use inference save model
  use_inference: True
  save_inference_feed_varnames: ["userid", "gender", "age", "occupation"]
  save_inference_fetch_varnames: ["linear_11.tmp_1"]

# hyper parameters of user-defined network
hyper_parameters:
  # optimizer config
  optimizer:
    class: Adam
    learning_rate: 0.001
  # user-defined <key, value> pairs
  sparse_feature_number: 600000
  sparse_feature_dim: 9
  dense_input_dim: 13
  fc_sizes: [512, 256, 128, 32]
```

models/rank/wide_deep/README.md

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ wide&deep designs an approach that fuses a shallow (wide) model and a deep (deep) model

 | Model | auc | batch_size | thread_num | epoch_num | Time of each epoch |
 | :------| :------ | :------| :------ | :------| :------ |
-| wide_deep | 0.82 | 512 | 1 | 4 | about 2 hours |
+| wide_deep | 0.79 | 512 | 1 | 4 | about 2 hours |

 1. Make sure your current directory is PaddleRec/models/rank/wide_deep
 2. Go into the paddlerec/datasets/criteo directory and run the script there; it downloads our preprocessed full Criteo dataset from a domestic mirror and unpacks it into the target folder.

models/rank/wide_deep/config.yaml

Lines changed: 5 additions & 1 deletion
@@ -28,8 +28,12 @@ runner:
   infer_reader_path: "criteo_reader" # importlib format
   infer_batch_size: 5
   infer_load_path: "output_model_wide_deep"
-  infer_start_epoch: 0
+  infer_start_epoch: 2
   infer_end_epoch: 3
+  # use inference save model
+  use_inference: False
+  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"]
+  save_inference_fetch_varnames: ["cast_0.tmp_0"]

 # hyper parameters of user-defined network
 hyper_parameters:

models/rank/wide_deep/static_model.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ def net(self, input, is_infer=False):
             label=self.label_input,
             num_thresholds=2**12,
             slide_steps=20)
+        auc = paddle.cast(auc, "float32")
         self.inference_target_var = auc
         if is_infer:
             fetch_dict = {'auc': auc}
