
Commit 843b63b

jacquesqiao authored and reyoung committed

add config_parser in trainer_config_helpers to separate trainer config
1 parent 3a80272 commit 843b63b

File tree

4 files changed: +98 −39 lines changed

demo/mnist/api_train.py

Lines changed: 23 additions & 5 deletions
@@ -9,11 +9,29 @@
 import py_paddle.swig_paddle as api
 from py_paddle import DataProviderConverter
 import paddle.trainer.PyDataProvider2 as dp
-import paddle.trainer.config_parser
 import numpy as np
 import random
 from mnist_util import read_from_mnist

+import paddle.trainer_config_helpers.config_parser as config_parser
+from paddle.trainer_config_helpers import *
+
+
+def optimizer_config():
+    settings(
+        learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)
+
+
+def network_config():
+    imgs = data_layer(name='pixel', size=784)
+    hidden1 = fc_layer(input=imgs, size=200)
+    hidden2 = fc_layer(input=hidden1, size=200)
+    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
+    cost = classification_cost(
+        input=inference, label=data_layer(
+            name='label', size=10))
+    outputs(cost)
+

 def init_parameter(network):
     assert isinstance(network, api.GradientMachine)
@@ -54,20 +72,20 @@ def input_order_converter(generator):

 def main():
     api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
-    config = paddle.trainer.config_parser.parse_config(
-        'simple_mnist_network.py', '')

     # get enable_types for each optimizer.
     # enable_types = [value, gradient, momentum, etc]
     # For each optimizer(SGD, Adam), GradientMachine should enable different
     # buffers.
-    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
+    opt_config_proto = config_parser.parse_optimizer_config(optimizer_config)
+    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
     _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
     enable_types = _temp_optimizer_.getParameterTypes()

     # Create Simple Gradient Machine.
+    model_config = config_parser.parse_network_config(network_config)
     m = api.GradientMachine.createFromConfigProto(
-        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
+        model_config, api.CREATE_MODE_NORMAL, enable_types)

     # This type check is not useful. Only enable type hint in IDE.
     # Such as PyCharm
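The demo change above replaces the single parse of simple_mnist_network.py with two targeted parses of in-file functions. As a recap, here is a minimal sketch of the new calling pattern, condensed from the diff (it assumes the Python 2 py_paddle bindings are built and importable; the layer sizes are illustrative, not the demo's):

import py_paddle.swig_paddle as api
import paddle.trainer_config_helpers.config_parser as config_parser
from paddle.trainer_config_helpers import *


def optimizer_config():
    # optimizer settings live in their own function ...
    settings(
        learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)


def network_config():
    # ... and the topology in another, so each can be parsed independently
    imgs = data_layer(name='pixel', size=784)
    outputs(fc_layer(input=imgs, size=10, act=SoftmaxActivation()))


api.initPaddle("-use_gpu=false", "-trainer_count=1")
# parse each piece into its own proto ...
opt_proto = config_parser.parse_optimizer_config(optimizer_config)
net_proto = config_parser.parse_network_config(network_config)
# ... then wire them into the GradientMachine exactly as main() does
opt_config = api.OptimizationConfig.createFromProto(opt_proto)
enable_types = api.ParameterOptimizer.create(opt_config).getParameterTypes()
m = api.GradientMachine.createFromConfigProto(
    net_proto, api.CREATE_MODE_NORMAL, enable_types)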

python/paddle/trainer/config_parser.py

Lines changed: 36 additions & 34 deletions
@@ -3416,8 +3416,35 @@ def register_parse_config_hook(f):
     _parse_config_hooks.add(f)


-def parse_config(config_file, config_arg_str):
+def update_g_config():
     '''
+    Update g_config after executing the config file or config functions.
+    '''
+    for k, v in settings.iteritems():
+        if v is None:
+            continue
+        g_config.opt_config.__setattr__(k, v)
+
+    for k, v in trainer_settings.iteritems():
+        if v is None:
+            continue
+        g_config.__setattr__(k, v)
+
+    for name in g_config.model_config.input_layer_names:
+        assert name in g_layer_map, \
+            'input name "%s" does not correspond to a layer name' % name
+        assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
+            'The type of input layer "%s" is not "data"' % name
+    for name in g_config.model_config.output_layer_names:
+        assert name in g_layer_map, \
+            'output name "%s" does not correspond to a layer name' % name
+    return g_config
+
+
+def parse_config(trainer_config, config_arg_str):
+    '''
+    @param trainer_config: can be a string of config file name or a function
+                           with config logic
     @param config_arg_str: a string of the form var1=val1,var2=val2. It will be
                            passed to config script as a dictionary CONFIG_ARGS
     '''
@@ -3451,45 +3478,20 @@ def parse_config(config_file, config_arg_str):
     g_root_submodel.is_recurrent_layer_group = False
     g_current_submodel = g_root_submodel

-    # for paddle on spark, need support non-file config.
-    # you can use parse_config like below:
-    #
-    # from paddle.trainer.config_parser import parse_config
-    # def configs():
-    #    #your paddle config code, which is same as config file.
-    #
-    # config = parse_config(configs, "is_predict=1")
-    # # then you get config proto object.
-    if hasattr(config_file, '__call__'):
-        config_file.func_globals.update(
+    if hasattr(trainer_config, '__call__'):
+        trainer_config.func_globals.update(
             make_config_environment("", config_args))
-        config_file()
+        trainer_config()
     else:
-        execfile(config_file, make_config_environment(config_file, config_args))
-    for k, v in settings.iteritems():
-        if v is None:
-            continue
-        g_config.opt_config.__setattr__(k, v)
-
-    for k, v in trainer_settings.iteritems():
-        if v is None:
-            continue
-        g_config.__setattr__(k, v)
+        execfile(trainer_config,
+                 make_config_environment(trainer_config, config_args))

-    for name in g_config.model_config.input_layer_names:
-        assert name in g_layer_map, \
-            'input name "%s" does not correspond to a layer name' % name
-        assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
-            'The type of input layer "%s" is not "data"' % name
-    for name in g_config.model_config.output_layer_names:
-        assert name in g_layer_map, \
-            'input name "%s" does not correspond to a layer name' % name
-    return g_config
+    return update_g_config()


-def parse_config_and_serialize(config_file, config_arg_str):
+def parse_config_and_serialize(trainer_config, config_arg_str):
     try:
-        config = parse_config(config_file, config_arg_str)
+        config = parse_config(trainer_config, config_arg_str)
         #logger.info(config)
         return config.SerializeToString()
     except:
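Note that the refactor keeps parse_config's dual interface: the hasattr(trainer_config, '__call__') branch still accepts a plain function (the "paddle on spark" use case described in the removed comment block), and the else branch still execfiles a config file; only the post-processing of g_config moved into the new update_g_config(). A small sketch of both call styles (Python 2, to match the execfile/func_globals usage above; the file name is hypothetical):

from paddle.trainer.config_parser import parse_config

# style 1: path to a config file on disk, plus a CONFIG_ARGS string
config = parse_config('my_trainer_config.py', 'is_predict=1')


# style 2: a function whose body is the same code a config file would hold
def configs():
    pass  # your paddle config code here


config = parse_config(configs, 'is_predict=1')
# either way, parse_config now ends with `return update_g_config()`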

python/paddle/trainer_config_helpers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 from networks import *
 from optimizers import *
 from attrs import *
+from config_parser import *

 # This will enable operator overload for LayerOutput
 import math as layer_math
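With this one-line re-export (plus the __all__ list in the new module, shown next), the wrapper functions become reachable from the package's top-level star import as well. A quick sketch:

from paddle.trainer_config_helpers import parse_network_config
import paddle.trainer_config_helpers.config_parser as config_parser

# both names resolve to the same function object after `from config_parser import *`
assert parse_network_config is config_parser.parse_network_config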
python/paddle/trainer_config_helpers/config_parser.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.trainer.config_parser as config_parser
+'''
+This file is a wrapper of the formal config_parser. The main idea of this file
+is to separate different config logic into different functions, such as network
+configuration and optimizer configuration.
+'''
+
+__all__ = [
+    "parse_trainer_config", "parse_network_config", "parse_optimizer_config"
+]
+
+
+def parse_trainer_config(trainer_conf, config_arg_str):
+    return config_parser.parse_config(trainer_conf, config_arg_str)
+
+
+def parse_network_config(network_conf):
+    config = config_parser.parse_config(network_conf, '')
+    return config.model_config
+
+
+def parse_optimizer_config(optimizer_conf):
+    config = config_parser.parse_config(optimizer_conf, '')
+    return config.opt_config
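Each helper is a thin projection over parse_config: parse_trainer_config forwards the caller's config_arg_str and returns the whole combined config proto, while the other two fix config_arg_str to '' and return only one sub-message (model_config or opt_config, the same fields update_g_config() populates). A hedged sketch of the distinction, using a network_config like the demo's:

import paddle.trainer_config_helpers.config_parser as config_parser
from paddle.trainer_config_helpers import *


def network_config():
    imgs = data_layer(name='pixel', size=784)
    outputs(fc_layer(input=imgs, size=10, act=SoftmaxActivation()))


# parse_network_config returns only config.model_config
model_proto = config_parser.parse_network_config(network_config)
assert 'pixel' in model_proto.input_layer_names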
