Skip to content

Commit cdecd53

Browse files
authored
Merge pull request #1498 from reyoung/feature/expose_networks
Feature/expose networks
2 parents 5c63418 + 3590cb1 commit cdecd53

File tree

4 files changed

+70
-1
lines changed

4 files changed

+70
-1
lines changed

python/paddle/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
from . import reader
2525
import attr
2626
import pooling
27+
import networks
2728
import py_paddle.swig_paddle as api
2829

2930
__all__ = [
3031
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
3132
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
32-
'topology'
33+
'topology', 'networks'
3334
]
3435

3536

python/paddle/v2/config_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import collections
216

317
from paddle.trainer_config_helpers.default_decorators import wrap_name_default

python/paddle/v2/networks.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle.trainer_config_helpers.networks as conf_nw
16+
import inspect
17+
from config_base import __convert_to_v2__
18+
19+
__all__ = []
20+
21+
22+
def __initialize__():
23+
for each_subnetwork in conf_nw.__all__:
24+
if each_subnetwork in ['inputs', 'outputs']:
25+
continue
26+
func = getattr(conf_nw, each_subnetwork)
27+
if hasattr(func, 'argspec'):
28+
argspec = func.argspec
29+
else:
30+
argspec = inspect.getargspec(func)
31+
if each_subnetwork == 'simple_attention':
32+
parents = ['encoded_sequence', 'encoded_proj', 'decoder_state']
33+
else:
34+
parents = filter(lambda x: x.startswith('input'), argspec.args)
35+
assert len(parents) != 0, each_subnetwork
36+
v2_subnet = __convert_to_v2__(
37+
each_subnetwork,
38+
parent_names=parents,
39+
is_default_name='name' in argspec.args)
40+
globals()[each_subnetwork] = v2_subnet
41+
global __all__
42+
__all__.append(each_subnetwork)
43+
44+
45+
__initialize__()

python/paddle/v2/tests/test_layer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import paddle.v2.data_type as data_type
1919
import paddle.v2.layer as layer
2020
import paddle.v2.pooling as pooling
21+
import paddle.v2.networks as networks
2122

2223
pixel = layer.data(name='pixel', type=data_type.dense_vector(128))
2324
label = layer.data(name='label', type=data_type.integer_value(10))
@@ -251,5 +252,13 @@ def test_operator(self):
251252
print layer.parse_network(conv1)
252253

253254

255+
class NetworkTests(unittest.TestCase):
256+
def test_vgg(self):
257+
img = layer.data(name='pixel', type=data_type.dense_vector(784))
258+
vgg_out = networks.small_vgg(
259+
input_image=img, num_channels=1, num_classes=2)
260+
print layer.parse_network(vgg_out)
261+
262+
254263
if __name__ == '__main__':
255264
unittest.main()

0 commit comments

Comments
 (0)