
Commit 91f13e4

Merge pull request #1465 from reyoung/feature/tester
Paddle.V2.Trainer.test method complete.
2 parents: b63d38d + b9f8cc0

5 files changed: 107 additions & 73 deletions

demo/mnist/api_train_v2.py

Lines changed: 13 additions & 10 deletions
@@ -20,26 +20,29 @@ def main():
 
     adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
 
+    trainer = paddle.trainer.SGD(cost=cost,
+                                 parameters=parameters,
+                                 update_equation=adam_optimizer)
+
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
-            if event.batch_id % 100 == 0:
-                print "Pass %d, Batch %d, Cost %f, %s" % (
-                    event.pass_id, event.batch_id, event.cost, event.metrics)
+            if event.batch_id % 1000 == 0:
+                result = trainer.test(reader=paddle.reader.batched(
+                    paddle.dataset.mnist.test(), batch_size=256))
+
+                print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
+                    event.pass_id, event.batch_id, event.cost, event.metrics,
+                    result.metrics)
+
         else:
             pass
 
-    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
-
     trainer.train(
         reader=paddle.reader.batched(
             paddle.reader.shuffle(
                 paddle.dataset.mnist.train(), buf_size=8192),
             batch_size=32),
-        cost=cost,
-        parameters=parameters,
-        event_handler=event_handler,
-        reader_dict={images.name: 0,
-                     label.name: 1})
+        event_handler=event_handler)
 
 
 if __name__ == '__main__':
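
For context, the cost, parameters, images, and label objects used in this hunk are defined earlier in main(), outside the displayed range. A minimal sketch of what that preamble typically looks like in the v2 MNIST demo (the layer names, sizes, and single fc layer here are illustrative assumptions, not part of this diff):

    import paddle.v2 as paddle

    paddle.init(use_gpu=False, trainer_count=1)

    # Data layers: a 784-dim flattened image and a 10-class label (assumed shapes).
    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))

    # A small softmax classifier; the real demo network may differ.
    inference = paddle.layer.fc(
        input=images, size=10, act=paddle.activation.Softmax())
    cost = paddle.layer.classification_cost(input=inference, label=label)

    # Parameters are created from the cost topology and handed to the trainer.
    parameters = paddle.parameters.create(cost)

With the new constructor, cost and parameters move from trainer.train() into paddle.trainer.SGD(...), which is what lets the event handler call trainer.test() in the middle of training.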

python/paddle/v2/dataset/mnist.py

Lines changed: 2 additions & 2 deletions
@@ -9,9 +9,9 @@
 
 URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
 TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
-TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
+TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
 TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
-TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
+TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
 TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
 TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
 TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
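
The two new values are presumably the MD5 digests of the gzipped t10k archives as served from yann.lecun.com. As a hedged sketch of how such constants are typically checked against a download (md5file below is an illustrative helper, not PaddlePaddle's actual download code):

    import hashlib

    def md5file(path, chunk_size=4096):
        # Compute the MD5 hex digest of a file, reading it in chunks.
        digest = hashlib.md5()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    # A downloaded archive is accepted only if its digest matches the constant.
    assert md5file('t10k-images-idx3-ubyte.gz') == TEST_IMAGE_MD5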

python/paddle/v2/event.py

Lines changed: 9 additions & 1 deletion
@@ -11,7 +11,10 @@
 TODO(yuyang18): Complete it!
 """
 import py_paddle.swig_paddle as api
-__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
+
+__all__ = [
+    'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
+]
 
 
 class WithMetric(object):
@@ -30,6 +33,11 @@ def metrics(self):
         return retv
 
 
+class TestResult(WithMetric):
+    def __init__(self, evaluator):
+        super(TestResult, self).__init__(evaluator)
+
+
 class BeginPass(object):
     """
     Event On One Pass Training Start.
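
TestResult reuses WithMetric, so the object returned by trainer.test() exposes the evaluator's metrics the same way EndIteration and EndPass events do. A minimal usage sketch (the reader and batch size are placeholders):

    result = trainer.test(reader=paddle.reader.batched(
        paddle.dataset.mnist.test(), batch_size=128))
    assert isinstance(result, paddle.event.TestResult)
    print result.metrics  # metrics collected by the evaluator over the test set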

python/paddle/v2/topology.py

Lines changed: 25 additions & 24 deletions
@@ -21,6 +21,14 @@
 __all__ = ['Topology']
 
 
+def __bfs_travel__(callback, *layers):
+    for each_layer in layers:
+        __break__ = callback(each_layer)
+        if __break__:
+            return
+        __bfs_travel__(callback, *each_layer.__parent_layers__.values())
+
+
 class Topology(object):
     """
     Topology is used to store the information about all layers
@@ -46,48 +54,41 @@ def get_layer(self, name):
         :param name:
         :return:
         """
-        result_layer = []
+        result_layer = [None]
 
-        def find_layer_by_name(layer, layer_name):
-            if len(result_layer) == 1:
-                return
-            elif layer.name == layer_name:
-                result_layer.append(layer)
-            else:
-                for parent_layer in layer.__parent_layers__.values():
-                    find_layer_by_name(parent_layer, layer_name)
+        def __impl__(l):
+            if l.name == name:
+                result_layer[0] = l
+                return True  # break
+            return False
 
-        for layer in self.layers:
-            find_layer_by_name(layer, name)
-
-        assert len(result_layer) == 1
+        __bfs_travel__(__impl__, *self.layers)
+        if result_layer[0] is None:
+            raise ValueError("No such layer %s" % name)
         return result_layer[0]
 
     def data_layers(self):
         """
         get all data layer
         :return:
         """
-        data_layers = set()
-
-        def find_data_layer(layer):
-            if isinstance(layer, v2_layer.DataLayerV2):
-                data_layers.add(layer)
-            for parent_layer in layer.__parent_layers__.values():
-                find_data_layer(parent_layer)
+        data_layers = dict()
 
-        for layer in self.layers:
-            find_data_layer(layer)
+        def __impl__(l):
+            if isinstance(l, v2_layer.DataLayerV2):
+                data_layers[l.name] = l
 
+        __bfs_travel__(__impl__, *self.layers)
         return data_layers
 
     def data_type(self):
         """
         get data_type from proto, such as:
         [('image', dense_vector(768)), ('label', integer_value(10))]
         """
-        return [(data_layer.name, data_layer.type)
-                for data_layer in self.data_layers()]
+        data_layers = self.data_layers()
+        return [(nm, data_layers[nm].type)
+                for nm in self.proto().input_layer_names]
 
 
 def __check_layer_type__(layer):
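
__bfs_travel__ folds the two earlier recursive helpers into one generic traversal: the callback is applied to a layer first, a truthy return value stops the walk, and otherwise the recursion continues into that layer's __parent_layers__. A self-contained sketch of the same pattern with stand-in layer objects (FakeLayer and the layer names are hypothetical, for illustration only):

    class FakeLayer(object):
        def __init__(self, name, *parents):
            self.name = name
            # mimic the attribute the traversal reads on v2 layers
            self.__parent_layers__ = dict((p.name, p) for p in parents)

    def bfs_travel(callback, *layers):
        for each_layer in layers:
            if callback(each_layer):
                return  # callback asked to stop early
            bfs_travel(callback, *each_layer.__parent_layers__.values())

    data = FakeLayer('pixel')
    hidden = FakeLayer('hidden', data)
    cost = FakeLayer('cost', hidden)

    visited = []
    bfs_travel(lambda l: visited.append(l.name), cost)
    print visited  # ['cost', 'hidden', 'pixel']

get_layer() uses the early-stop return to finish as soon as the named layer is found, while the callback in data_layers() returns nothing, so the whole graph is visited.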

python/paddle/v2/trainer.py

Lines changed: 58 additions & 36 deletions
@@ -42,25 +42,35 @@ def train(self, reader, topology, parameters, event_handler=None):
 
 
 class SGD(ITrainer):
-    def __init__(self, update_equation):
+    def __init__(self, cost, parameters, update_equation):
         """
         Simple SGD Trainer.
 
         :param update_equation: The optimizer object.
         :type update_equation: v2_optimizer.Optimizer
         """
+
+        if not isinstance(parameters, v2_parameters.Parameters):
+            raise TypeError('parameters should be parameters')
+
         if not isinstance(update_equation, v2_optimizer.Optimizer):
-            raise ValueError("update equation parameter must be "
-                             "paddle.v2.optimizer.Optimizer")
+            raise TypeError("update equation parameter must be "
+                            "paddle.v2.optimizer.Optimizer")
+        topology = Topology(cost)
         self.__optimizer__ = update_equation
+        self.__topology__ = topology
+        self.__parameters__ = parameters
+        self.__topology_in_proto__ = topology.proto()
+        self.__data_types__ = topology.data_type()
+        gm = api.GradientMachine.createFromConfigProto(
+            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
+            self.__optimizer__.enable_types())
+        assert isinstance(gm, api.GradientMachine)
+        parameters.append_gradient_machine(gm)
+        self.__gradient_machine__ = gm
+        self.__gradient_machine__.randParameters()
 
-    def train(self,
-              reader,
-              cost,
-              parameters,
-              num_passes=1,
-              event_handler=None,
-              reader_dict=None):
+    def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
         """
         Training method. Will train num_passes of input data.
 
@@ -76,44 +86,41 @@ def train(self,
         if event_handler is None:
             event_handler = default_event_handler
 
-        topology = Topology(cost)
+        if reader_dict is None:
+            reader_dict = self.default_reader_dict()
 
         __check_train_args__(**locals())
 
-        gm = api.GradientMachine.createFromConfigProto(
-            topology.proto(), api.CREATE_MODE_NORMAL,
-            self.__optimizer__.enable_types())
-        assert isinstance(gm, api.GradientMachine)
-        parameters.append_gradient_machine(gm)
-        gm.randParameters()
         updater = self.__optimizer__.create_local_updater()
-        updater.init(gm)
+        updater.init(self.__gradient_machine__)
 
-        gm.start()
-        batch_evaluator = gm.makeEvaluator()
+        self.__gradient_machine__.start()
+        batch_evaluator = self.__gradient_machine__.makeEvaluator()
         assert isinstance(batch_evaluator, api.Evaluator)
-        pass_evaluator = gm.makeEvaluator()
+        pass_evaluator = self.__gradient_machine__.makeEvaluator()
         assert isinstance(pass_evaluator, api.Evaluator)
         out_args = api.Arguments.createArguments(0)
 
-        feeder = DataFeeder(topology.data_type(), reader_dict)
+        feeder = DataFeeder(self.__data_types__, reader_dict)
 
        for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
             updater.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 pass_type = updater.startBatch(len(data_batch))
-                gm.forwardBackward(feeder(data_batch), out_args, pass_type)
+                self.__gradient_machine__.forwardBackward(
+                    feeder(data_batch), out_args, pass_type)
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
                 pass_type = updater.startBatch(len(data_batch))
-                gm.forwardBackward(feeder(data_batch), out_args, pass_type)
-                gm.eval(pass_evaluator)
-                gm.eval(batch_evaluator)
-                for each_param in gm.getParameters():
+                self.__gradient_machine__.forwardBackward(
+                    feeder(data_batch), out_args, pass_type)
+                self.__gradient_machine__.eval(pass_evaluator)
+                self.__gradient_machine__.eval(batch_evaluator)
+                for each_param in self.__gradient_machine__.getParameters():
                     updater.update(each_param)
                 # Get cost. We use numpy to calculate total cost for this batch.
                 cost_vec = out_args.getSlotValue(0)
@@ -131,22 +138,37 @@ def train(self,
             updater.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
-        gm.finish()
+        self.__gradient_machine__.finish()
+
+    def default_reader_dict(self):
+        reader_dict = dict()
+        for i, tp in enumerate(self.__data_types__):
+            reader_dict[tp[0]] = i
+        return reader_dict
+
+    def test(self, reader, reader_dict=None):
+        if reader_dict is None:
+            reader_dict = self.default_reader_dict()
+
+        feeder = DataFeeder(self.__data_types__, reader_dict)
+        evaluator = self.__gradient_machine__.makeEvaluator()
+        out_args = api.Arguments.createArguments(0)
+        evaluator.start()
+        for data_batch in reader():
+            self.__gradient_machine__.forward(
+                feeder(data_batch), out_args, api.PASS_TEST)
+            self.__gradient_machine__.eval(evaluator)
+
+        evaluator.finish()
+        return v2_event.TestResult(evaluator=evaluator)
 
 
-def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
+def __check_train_args__(reader, event_handler, **kwargs):
     """
     Check train function's argument types
     """
     if not callable(reader) or not isinstance(reader(), collections.Iterator):
         raise TypeError('train_data_reader should be a function, '
                         'which can return a iterator')
-
-    if not isinstance(topology, Topology):
-        raise TypeError('topology should be a model config')
-
-    if not isinstance(parameters, v2_parameters.Parameters):
-        raise TypeError('parameters should be a parameter pool')
-
     if not callable(event_handler):
         raise TypeError('event handler should be a function')
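
With reader_dict now optional, default_reader_dict() derives the name-to-column mapping from the topology's data types, so a plain tuple reader works without extra wiring. An illustration of the mapping it produces (the layer names below are assumptions, not taken from this diff):

    # Suppose the topology's data_type() returned
    #     [('pixel', dense_vector(784)), ('label', integer_value(10))]
    # then default_reader_dict() yields {'pixel': 0, 'label': 1}, meaning each
    # reader sample is a (image_vector, label_int) tuple fed in that order.
    data_types = [('pixel', None), ('label', None)]  # stand-in for __data_types__
    reader_dict = dict()
    for i, tp in enumerate(data_types):
        reader_dict[tp[0]] = i
    assert reader_dict == {'pixel': 0, 'label': 1}

test() then runs forward-only passes (api.PASS_TEST) over the supplied reader, feeding batches through the persistent gradient machine and returning the finished evaluator wrapped in a TestResult event.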
