Skip to content

Commit 7e6def5

Browse files
authored
Merge pull request #1756 from qingqing01/v2_api_multi_leaf_node
Add extra_layers in paddle.trainer.SGD.
2 parents 8a01644 + 23283f2 commit 7e6def5

File tree

4 files changed

+60
-28
lines changed

4 files changed

+60
-28
lines changed

python/paddle/v2/layer.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,41 @@
5353
__all__ = ['parse_network', 'data']
5454

5555

56-
def parse_network(*outputs):
56+
def parse_network(output_layers, extra_layers=None):
5757
"""
58-
Parse all output layers and then generate a ModelConfig object.
58+
Parse all layers in the neural network graph and
59+
then generate a ModelConfig object.
5960
6061
.. note::
6162
6263
This function is used internally in paddle.v2 module. User should never
6364
invoke this method.
6465
65-
:param outputs: Output layers.
66-
:type outputs: Layer
66+
:param output_layers: Output layers.
67+
:type output_layers: Layer
68+
:param extra_layers: Some layers in the neural network graph are not in the
69+
path of output_layers.
70+
:type extra_layers: Layer
6771
:return: A ModelConfig object instance.
6872
:rtype: ModelConfig
6973
"""
74+
if not isinstance(output_layers, collections.Sequence):
75+
output_layers = [output_layers]
76+
if extra_layers is not None and not isinstance(extra_layers,
77+
collections.Sequence):
78+
extra_layers = [extra_layers]
7079

7180
def __real_func__():
7281
"""
7382
__real_func__ is the function that config_parser.parse invoked. It is
7483
the plain old paddle configuration function.
7584
"""
7685
context = dict()
77-
real_output = [each.to_proto(context=context) for each in outputs]
86+
real_output = [each.to_proto(context=context) for each in output_layers]
87+
if extra_layers is not None:
88+
extra_output = [
89+
each.to_proto(context=context) for each in extra_layers
90+
]
7891
conf_helps.outputs(real_output)
7992

8093
return __parse__(__real_func__)

python/paddle/v2/tests/test_layer.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ def test_pooling_layer(self):
5959
num_channels=16,
6060
pool_type=pooling.Max())
6161
maxout = layer.maxout(input=conv, num_channels=16, groups=4)
62-
print layer.parse_network(maxpool, spp, maxout)
62+
print layer.parse_network([maxpool, spp, maxout])
6363

6464
def test_norm_layer(self):
6565
norm1 = layer.img_cmrnorm(input=conv, size=5)
6666
norm2 = layer.batch_norm(input=conv)
6767
norm3 = layer.sum_to_one_norm(input=conv)
68-
print layer.parse_network(norm1, norm2, norm3)
68+
print layer.parse_network([norm1, norm2, norm3])
6969

7070

7171
class AggregateLayerTest(unittest.TestCase):
@@ -78,7 +78,8 @@ def test_aggregate_layer(self):
7878
first_seq = layer.first_seq(input=pixel)
7979
concat = layer.concat(input=[last_seq, first_seq])
8080
seq_concat = layer.seq_concat(a=last_seq, b=first_seq)
81-
print layer.parse_network(pool, last_seq, first_seq, concat, seq_concat)
81+
print layer.parse_network(
82+
[pool, last_seq, first_seq, concat, seq_concat])
8283

8384

8485
class MathLayerTest(unittest.TestCase):
@@ -95,8 +96,10 @@ def test_math_layer(self):
9596
tensor = layer.tensor(a=pixel, b=pixel, size=1000)
9697
cos_sim = layer.cos_sim(a=pixel, b=pixel)
9798
trans = layer.trans(input=tensor)
98-
print layer.parse_network(addto, linear_comb, interpolation, power,
99-
scaling, slope, tensor, cos_sim, trans)
99+
print layer.parse_network([
100+
addto, linear_comb, interpolation, power, scaling, slope, tensor,
101+
cos_sim, trans
102+
])
100103

101104

102105
class ReshapeLayerTest(unittest.TestCase):
@@ -110,7 +113,8 @@ def test_reshape_layer(self):
110113
repeat = layer.repeat(input=pixel, num_repeats=4)
111114
reshape = layer.seq_reshape(input=pixel, reshape_size=4)
112115
rotate = layer.rotate(input=pixel, height=16, width=49)
113-
print layer.parse_network(block_expand, expand, repeat, reshape, rotate)
116+
print layer.parse_network(
117+
[block_expand, expand, repeat, reshape, rotate])
114118

115119

116120
class RecurrentLayerTest(unittest.TestCase):
@@ -119,7 +123,7 @@ def test_recurrent_layer(self):
119123
recurrent = layer.recurrent(input=word)
120124
lstm = layer.lstmemory(input=word)
121125
gru = layer.grumemory(input=word)
122-
print layer.parse_network(recurrent, lstm, gru)
126+
print layer.parse_network([recurrent, lstm, gru])
123127

124128

125129
class CostLayerTest(unittest.TestCase):
@@ -139,10 +143,10 @@ def test_cost_layer(self):
139143
cost10 = layer.sum_cost(input=inference)
140144
cost11 = layer.huber_cost(input=score, label=label)
141145

142-
print layer.parse_network(cost1, cost2)
143-
print layer.parse_network(cost3, cost4)
144-
print layer.parse_network(cost5, cost6)
145-
print layer.parse_network(cost7, cost8, cost9, cost10, cost11)
146+
print layer.parse_network([cost1, cost2])
147+
print layer.parse_network([cost3, cost4])
148+
print layer.parse_network([cost5, cost6])
149+
print layer.parse_network([cost7, cost8, cost9, cost10, cost11])
146150

147151
crf = layer.crf(input=inference, label=label)
148152
crf_decoding = layer.crf_decoding(input=inference, size=3)
@@ -151,16 +155,16 @@ def test_cost_layer(self):
151155
nce = layer.nce(input=inference, label=label, num_classes=3)
152156
hsigmoid = layer.hsigmoid(input=inference, label=label, num_classes=3)
153157

154-
print layer.parse_network(crf, crf_decoding, ctc, warp_ctc, nce,
155-
hsigmoid)
158+
print layer.parse_network(
159+
[crf, crf_decoding, ctc, warp_ctc, nce, hsigmoid])
156160

157161

158162
class OtherLayerTest(unittest.TestCase):
159163
def test_sampling_layer(self):
160164
maxid = layer.max_id(input=inference)
161165
sampling_id = layer.sampling_id(input=inference)
162166
eos = layer.eos(input=maxid, eos_id=5)
163-
print layer.parse_network(maxid, sampling_id, eos)
167+
print layer.parse_network([maxid, sampling_id, eos])
164168

165169
def test_slicing_joining_layer(self):
166170
pad = layer.pad(input=conv, pad_c=[2, 3], pad_h=[1, 2], pad_w=[3, 1])

python/paddle/v2/topology.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,26 @@ class Topology(object):
5353
and network configs.
5454
"""
5555

56-
def __init__(self, layers):
57-
if not isinstance(layers, collections.Sequence):
58-
__check_layer_type__(layers)
59-
layers = [layers]
60-
for layer in layers:
61-
__check_layer_type__(layer)
56+
def __init__(self, layers, extra_layers=None):
57+
def __check__(layers):
58+
if not isinstance(layers, collections.Sequence):
59+
__check_layer_type__(layers)
60+
layers = [layers]
61+
for layer in layers:
62+
__check_layer_type__(layer)
63+
return layers
64+
65+
layers = __check__(layers)
6266
self.layers = layers
63-
self.__model_config__ = v2_layer.parse_network(*layers)
67+
if extra_layers is not None:
68+
extra_layers = __check__(extra_layers)
69+
70+
self.__model_config__ = v2_layer.parse_network(
71+
layers, extra_layers=extra_layers)
72+
73+
if extra_layers is not None:
74+
self.layers.extend(extra_layers)
75+
6476
assert isinstance(self.__model_config__, ModelConfig)
6577

6678
def proto(self):

python/paddle/v2/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,20 @@ class SGD(object):
3737
:type cost: paddle.v2.config_base.Layer
3838
:param parameters: The parameters dictionary.
3939
:type parameters: paddle.v2.parameters.Parameters
40+
:param extra_layers: Some layers in the neural network graph are not
41+
in the path of cost layer.
42+
:type extra_layers: paddle.v2.config_base.Layer
4043
"""
4144

42-
def __init__(self, cost, parameters, update_equation):
45+
def __init__(self, cost, parameters, update_equation, extra_layers=None):
4346

4447
if not isinstance(parameters, v2_parameters.Parameters):
4548
raise TypeError('parameters should be parameters')
4649

4750
if not isinstance(update_equation, v2_optimizer.Optimizer):
4851
raise TypeError("update equation parameter must be "
4952
"paddle.v2.optimizer.Optimizer")
50-
topology = Topology(cost)
53+
topology = Topology(cost, extra_layers=extra_layers)
5154
self.__optimizer__ = update_equation
5255
self.__topology__ = topology
5356
self.__parameters__ = parameters

0 commit comments

Comments
 (0)