Skip to content

Commit af04cbd

Browse files
authored
Merge pull request #367 from frankwhzhang/dy2static
add dygraph to static
2 parents d623610 + bf29659 commit af04cbd

File tree

4 files changed

+162
-1
lines changed

4 files changed

+162
-1
lines changed

models/rank/dnn/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ runner:
2222
use_gpu: False
2323
use_auc: True
2424
train_batch_size: 2
25-
epochs: 3
25+
epochs: 1
2626
print_interval: 2
2727
model_save_path: "output_model_dnn"
2828
infer_batch_size: 2
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# workspace
16+
#workspace: "models/rank/dnn"
17+
18+
19+
runner:
20+
train_data_dir: "data/sample_data/train"
21+
train_reader_path: "criteo_reader" # importlib format
22+
model_init_path: "output_model_dnn/0" # model_init
23+
use_gpu: False
24+
use_auc: True
25+
train_batch_size: 2
26+
epochs: 1
27+
print_interval: 2
28+
model_save_path: "output_model_dnn2" # save path
29+
infer_batch_size: 2
30+
infer_reader_path: "criteo_reader" # importlib format
31+
test_data_dir: "data/sample_data/train"
32+
infer_load_path: "output_model_dnn"
33+
infer_start_epoch: 0
34+
infer_end_epoch: 3
35+
36+
# distribute_config
37+
sync_mode: "async"
38+
split_file_list: False
39+
thread_num: 1
40+
41+
42+
# hyper parameters of user-defined network
43+
hyper_parameters:
44+
# optimizer config
45+
optimizer:
46+
class: Adam
47+
learning_rate: 0.001
48+
strategy: async
49+
# user-defined <key, value> pairs
50+
sparse_inputs_slots: 27
51+
sparse_feature_number: 1000001
52+
sparse_feature_dim: 9
53+
dense_input_dim: 13
54+
fc_sizes: [512, 256, 128, 32]
55+
distributed_embedding: 0

tools/to_static.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
#
16+
# Licensed under the Apache License, Version 2.0 (the "License");
17+
# you may not use this file except in compliance with the License.
18+
# You may obtain a copy of the License at
19+
#
20+
# http://www.apache.org/licenses/LICENSE-2.0
21+
#
22+
# Unless required by applicable law or agreed to in writing, software
23+
# distributed under the License is distributed on an "AS IS" BASIS,
24+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25+
# See the License for the specific language governing permissions and
26+
# limitations under the License.
27+
28+
import paddle
29+
import os
30+
import paddle.nn as nn
31+
import time
32+
import logging
33+
import sys
34+
import importlib
35+
36+
__dir__ = os.path.dirname(os.path.abspath(__file__))
37+
#sys.path.append(__dir__)
38+
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
39+
40+
from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
41+
from utils.save_load import load_model, save_model, save_jit_model
42+
from paddle.io import DistributedBatchSampler, DataLoader
43+
import argparse
44+
45+
logging.basicConfig(
46+
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
47+
logger = logging.getLogger(__name__)
48+
49+
50+
def parse_args():
51+
parser = argparse.ArgumentParser(description='paddle-rec run')
52+
parser.add_argument("-m", "--config_yaml", type=str)
53+
args = parser.parse_args()
54+
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
55+
args.config_yaml = get_abs_model(args.config_yaml)
56+
return args
57+
58+
59+
def main(args):
60+
paddle.seed(12345)
61+
# load config
62+
config = load_yaml(args.config_yaml)
63+
dy_model_class = load_dy_model_class(args.abs_dir)
64+
config["config_abs_dir"] = args.abs_dir
65+
# tools.vars
66+
use_gpu = config.get("runner.use_gpu", True)
67+
train_data_dir = config.get("runner.train_data_dir", None)
68+
epochs = config.get("runner.epochs", None)
69+
print_interval = config.get("runner.print_interval", None)
70+
model_save_path = config.get("runner.model_save_path", "model_output")
71+
model_init_path = config.get("runner.model_init_path", None)
72+
73+
logger.info("**************common.configs**********")
74+
logger.info(
75+
"use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
76+
format(use_gpu, train_data_dir, epochs, print_interval,
77+
model_save_path))
78+
logger.info("**************common.configs**********")
79+
80+
place = paddle.set_device('gpu' if use_gpu else 'cpu')
81+
82+
dy_model = dy_model_class.create_model(config)
83+
84+
load_model(model_init_path, dy_model)
85+
# example dnn model forward
86+
dy_model = paddle.jit.to_static(
87+
dy_model,
88+
input_spec=[[
89+
paddle.static.InputSpec(
90+
shape=[None, 1], dtype='int64') for jj in range(26)
91+
], paddle.static.InputSpec(
92+
shape=[None, 13], dtype='float32')])
93+
save_jit_model(dy_model, model_save_path, prefix='tostatic')
94+
95+
96+
if __name__ == '__main__':
97+
args = parse_args()
98+
main(args)

tools/utils/save_load.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ def save_model(net, optimizer, model_path, epoch_id, prefix='rec'):
3030
logger.info("Already save model in {}".format(model_path))
3131

3232

33+
def save_jit_model(net, model_path, prefix='tostatic'):
34+
_mkdir_if_not_exist(model_path)
35+
model_prefix = os.path.join(model_path, prefix)
36+
#paddle.save(net.state_dict(), model_prefix + ".pdparams")
37+
paddle.jit.save(net, model_prefix)
38+
logger.info("Already save jit model in {}".format(model_path))
39+
40+
3341
def load_model(model_path, net, prefix='rec'):
3442
logger.info("start load model from {}".format(model_path))
3543
model_prefix = os.path.join(model_path, prefix)

0 commit comments

Comments
 (0)