Skip to content

Commit 763a30f

Browse files
jacquesqiaoreyoung
authored andcommitted
add config_parser_utils
1 parent 843b63b commit 763a30f

File tree

4 files changed

+48
-27
lines changed

4 files changed

+48
-27
lines changed

demo/mnist/api_train.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
import numpy as np
1313
import random
1414
from mnist_util import read_from_mnist
15-
16-
import paddle.trainer_config_helpers.config_parser as config_parser
15+
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
1716
from paddle.trainer_config_helpers import *
1817

1918

2019
def optimizer_config():
2120
settings(
22-
learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)
21+
learning_rate=1e-4,
22+
learning_method=AdamOptimizer(),
23+
batch_size=1000,
24+
model_average=ModelAverage(average_window=0.5),
25+
regularization=L2Regularization(rate=0.5))
2326

2427

2528
def network_config():
@@ -77,13 +80,14 @@ def main():
7780
# enable_types = [value, gradient, momentum, etc]
7881
# For each optimizer(SGD, Adam), GradientMachine should enable different
7982
# buffers.
80-
opt_config_proto = config_parser.parse_optimizer_config(optimizer_config)
83+
opt_config_proto = config_parser_utils.parse_optimizer_config(
84+
optimizer_config)
8185
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
8286
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
8387
enable_types = _temp_optimizer_.getParameterTypes()
8488

8589
# Create Simple Gradient Machine.
86-
model_config = config_parser.parse_network_config(network_config)
90+
model_config = config_parser_utils.parse_network_config(network_config)
8791
m = api.GradientMachine.createFromConfigProto(
8892
model_config, api.CREATE_MODE_NORMAL, enable_types)
8993

demo/mnist/simple_mnist_network.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

python/paddle/trainer_config_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from networks import *
2121
from optimizers import *
2222
from attrs import *
23-
from config_parser import *
23+
from config_parser_utils import *
2424

2525
# This will enable operator overload for LayerOutput
2626
import math as layer_math
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle.trainer.config_parser as config_parser
16+
'''
17+
This file is a wrapper of formal config_parser. The main idea of this file is to
18+
separete different config logic into different function, such as network configuration
19+
and optimizer configuration.
20+
'''
21+
22+
__all__ = [
23+
"parse_trainer_config", "parse_network_config", "parse_optimizer_config"
24+
]
25+
26+
27+
def parse_trainer_config(trainer_conf, config_arg_str):
28+
return config_parser.parse_config(trainer_conf, config_arg_str)
29+
30+
31+
def parse_network_config(network_conf, config_arg_str=''):
32+
config = config_parser.parse_config(network_conf, config_arg_str)
33+
return config.model_config
34+
35+
36+
def parse_optimizer_config(optimizer_conf, config_arg_str=''):
37+
config = config_parser.parse_config(optimizer_conf, config_arg_str)
38+
return config.opt_config

0 commit comments

Comments
 (0)