
Commit e56ccd4

Merge pull request #424 from yinhaofeng/dygraph_save
add dygraph to static
2 parents 9aecccf + 226f0a2

File tree: 5 files changed, +54 -10 lines changed

doc/inference.md

Lines changed: 36 additions & 3 deletions
@@ -1,5 +1,5 @@
 # How to use Paddle Inference
-PaddleRec currently supports saving a model with the save_inference_model interface during static-graph training, and deploying the saved model on the server side with the Paddle Inference library. This tutorial takes the wide_deep model as an example to show how to use these two features.
+PaddleRec currently supports saving a model with the save_inference_model interface during static-graph training, converting a model saved after dynamic-graph training into static-graph form, and deploying the saved model on the server side with the Paddle Inference library. This tutorial takes the wide_deep model as an example to show how to use these three features.

 ## Saving a model with the save_inference_model interface
 Deploying with Python on the server side first requires saving the model with the save_inference_model interface.
@@ -12,8 +12,8 @@ runner:
   ...
   # use inference save model
   use_inference: True # save as an inference model during static-graph training
-  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # names of the inference model's feed variables
-  save_inference_fetch_varnames: ["cast_0.tmp_0"] # names of the inference model's fetch variables
+  save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # names of the inference model's feed variables
+  save_inference_fetch_varnames: ["sigmoid_0.tmp_0"] # names of the inference model's fetch variables
 ```
 3. Start static-graph training
 ```bash
@@ -23,6 +23,39 @@
 python -u ../../../tools/static_trainer.py -m config.yaml # run config_bigdata.yaml for the full dataset
 ```

+## Converting a model saved from dynamic-graph training with the to_static.py script
+If you have finished dynamic-graph training and want to convert the saved model into a static-graph inference model, you can refer to the to_static.py script we provide.
+1. First run dynamic-graph training as usual and save the parameters
+```bash
+# enter the model directory
+# cd models/rank/wide_deep # can be run from any directory
+# dynamic-graph training
+python -u ../../../tools/trainer.py -m config.yaml # run config_bigdata.yaml for the full dataset
+```
+2. Open the yaml config and add the `model_init_path` option
+The to_static.py script first loads the model from the `model_init_path` location and then converts it to a static graph before saving. Do not enable this option when training from scratch, or the run will become a warm start.
+3. Modify the to_static script, rewriting its to_static statement to fit your model.
+Taking the wide_deep model as an example, we need to save the forward part of the wide_deep network; see [net.py](https://github.com/PaddlePaddle/PaddleRec/blob/master/models/rank/wide_deep/net.py) for the code. Its inputs are a list of 26 sparse features plus 1 dense feature. Each sparse feature has shape (batchsize, 1) and dtype int64; the dense feature has shape (batchsize, 13) and dtype float32.
+We therefore specify input_spec in the paddle.jit.to_static statement of the to_static script as shown below. For the detailed usage of input_spec, see [InputSpec 功能介绍](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/04_dygraph_to_static/input_spec_cn.html)
+```python
+# example dnn and wide_deep model forward
+dy_model = paddle.jit.to_static(dy_model,
+    input_spec=[[paddle.static.InputSpec(shape=[None, 1], dtype='int64') for jj in range(26)], paddle.static.InputSpec(shape=[None, 13], dtype='float32')])
+```
+4. Run the to_static script with your yaml file as its argument, and the model is saved: the parameters at the model_init_path location specified in your yaml file are converted and saved under the model_save_path/(infer_end_epoch-1) directory.
+Note: it is infer_end_epoch-1 because epochs are counted from 0, so running 3 epochs gives epochs 0~2.
+```bash
+python -u ../../../tools/to_static.py -m config.yaml
+```
+5. When predicting with the Inference library, the inputs and outputs must also be adjusted accordingly. The model we saved is the forward part of the wide_deep network: its inputs are a list of 26 sparse features plus 1 dense feature, and its output is the prediction value.
+Remove the label part from the input data in criteo_reader.py:
+```python
+# unchanged parts are omitted here
+# in the final output list, drop the first np.array, i.e. the label
+yield output_list[1:]
+```
+Compare the prediction values obtained from inference with the labels in the dataset and compute the AUC metric with a separate script.
+
 ## Deploying the saved model with the Inference library on the server side
 PaddleRec provides the tools/paddle_infer.py script so that you can conveniently and efficiently run model predictions with the Inference library.

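The added doc section above shows only the to_static wrapping step. As a hedged sketch of the whole dygraph-to-static flow it describes, the snippet below combines paddle.jit.to_static with paddle.jit.save; the ToyWideDeep layer, its simplified forward, and the output path are illustrative assumptions, not code from this PR (the real network is in models/rank/wide_deep/net.py).

```python
# Minimal sketch of the dygraph-to-static flow described in doc/inference.md.
# ToyWideDeep and the save path are assumptions made for illustration.
import paddle


class ToyWideDeep(paddle.nn.Layer):
    """Hypothetical stand-in for the real wide_deep net."""

    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(13, 1)

    def forward(self, sparse_inputs, dense_input):
        # a real model would also embed and combine the 26 sparse features
        return paddle.nn.functional.sigmoid(self.fc(dense_input))


dy_model = ToyWideDeep()
# wrap forward with static signatures, as in the to_static.py excerpt above:
# 26 sparse features of shape (batchsize, 1) int64, 1 dense feature (batchsize, 13) float32
dy_model = paddle.jit.to_static(
    dy_model,
    input_spec=[[paddle.static.InputSpec(shape=[None, 1], dtype='int64')
                 for _ in range(26)],
                paddle.static.InputSpec(shape=[None, 13], dtype='float32')])
# persist the static-graph program and parameters for Inference deployment
paddle.jit.save(dy_model, "output_model_wide_deep/2/tostatic")
```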
doc/serving.md

Lines changed: 11 additions & 2 deletions
@@ -12,8 +12,8 @@ runner:
   ...
   # use inference save model
   use_inference: True # save as an inference model during static-graph training
-  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # names of the inference model's feed variables
-  save_inference_fetch_varnames: ["cast_0.tmp_0"] # names of the inference model's fetch variables
+  save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # names of the inference model's feed variables
+  save_inference_fetch_varnames: ["sigmoid_0.tmp_0"] # names of the inference model's fetch variables
 ```
 3. Start static-graph training
 ```bash
@@ -102,6 +102,15 @@ python ../../../tools/webserver.py gpu 9393
 # CPU
 python ../../../tools/webserver.py cpu 9393
 ```
+### Adjusting the reader
+On the server side we use the Inference library for prediction under the hood. Just as when using the Inference library directly, the inputs and outputs must be adjusted in the reader. The model we saved is the forward part of the wide_deep network: its inputs are a list of 26 sparse features plus 1 dense feature, and its output is the prediction value.
+Remove the label part from the input data in criteo_reader.py:
+```python
+# unchanged parts are omitted here
+# in the final output list, drop the first np.array, i.e. the label
+yield output_list[1:]
+```
+Compare the predicted values with the labels in the dataset and compute the AUC metric with a separate script.

 ## Testing the deployed service
 After the serving service has started successfully on the server side, open a new terminal page to deploy the client.

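Both doc changes end by deferring AUC computation to "a separate script" without providing one. A minimal sketch of such a script follows; the file names and the one-value-per-line format are assumptions for illustration, not files from the repo.

```python
# Hedged sketch of the separate AUC script mentioned in the docs above;
# labels.txt and predictions.txt (one value per line) are assumed formats.
import numpy as np
from sklearn.metrics import roc_auc_score

labels = np.loadtxt("labels.txt")      # ground-truth labels from the dataset
preds = np.loadtxt("predictions.txt")  # prediction values from inference
print("auc: {:.6f}".format(roc_auc_score(labels, preds)))
```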
models/rank/wide_deep/config.yaml

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ runner:
   train_batch_size: 2
   epochs: 3
   print_interval: 2
-  #model_init_path: "output_model/0" # init model
+  # model_init_path: "output_model_wide_deep/2" # init model
   model_save_path: "output_model_wide_deep"
   test_data_dir: "data/sample_data/train"
   infer_reader_path: "criteo_reader" # importlib format
@@ -32,8 +32,8 @@ runner:
   infer_end_epoch: 3
   #use inference save model
   use_inference: False
-  save_inference_feed_varnames: ["label","C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"]
-  save_inference_fetch_varnames: ["cast_0.tmp_0"]
+  save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"]
+  save_inference_fetch_varnames: ["sigmoid_0.tmp_0"]

 # hyper parameters of user-defined network
 hyper_parameters:

tools/static_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -118,8 +118,8 @@ def main(args):
                 config, use_visual, log_visual, step_num)
             metric_str = ""
             for var_idx, var_name in enumerate(fetch_vars):
-                metric_str += "{}: {}, ".format(var_name,
-                                                fetch_batch_var[var_idx])
+                metric_str += "{}: {}, ".format(
+                    var_name, str(fetch_batch_var[var_idx]).strip("[]"))
             logger.info("epoch: {} done, ".format(epoch_id) + metric_str +
                         "epoch time: {:.2f} s".format(time.time() -
                                                       epoch_begin))

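For context on the two-line change above: each fetched metric comes back as a single-element numpy array, so formatting it directly prints brackets in the log line. A standalone illustration (the metric value is made up):

```python
# Illustrates the .strip("[]") added in tools/static_trainer.py above.
import numpy as np

fetch_batch_var = [np.array([0.8123])]
print("{}: {}, ".format("auc", fetch_batch_var[0]))                   # auc: [0.8123],
print("{}: {}, ".format("auc", str(fetch_batch_var[0]).strip("[]")))  # auc: 0.8123,
```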
tools/to_static.py

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ def main(args):
     print_interval = config.get("runner.print_interval", None)
     model_save_path = config.get("runner.model_save_path", "model_output")
     model_init_path = config.get("runner.model_init_path", None)
+    end_epoch = config.get("runner.infer_end_epoch", 0)

     logger.info("**************common.configs**********")
     logger.info(
@@ -80,6 +81,7 @@ def main(args):
     place = paddle.set_device('gpu' if use_gpu else 'cpu')

     dy_model = dy_model_class.create_model(config)
+    model_save_path = os.path.join(model_save_path, str(end_epoch - 1))

     load_model(model_init_path, dy_model)
     # example dnn model forward

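Combined with the sample config above (model_save_path: "output_model_wide_deep", infer_end_epoch: 3), the two added lines make to_static.py write the converted model under output_model_wide_deep/2, matching the infer_end_epoch-1 note in doc/inference.md. A quick check of the path composition:

```python
# Quick check of the save path composed by the two lines added above,
# using the values from the sample models/rank/wide_deep/config.yaml.
import os

model_save_path = "output_model_wide_deep"  # runner.model_save_path
end_epoch = 3                               # runner.infer_end_epoch
print(os.path.join(model_save_path, str(end_epoch - 1)))  # output_model_wide_deep/2
```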