
Commit aba84aa

Use Chain to refactor trainer.
1 parent 9b41b08 commit aba84aa
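The "Chain" refactor follows a chain-of-responsibility style: every RunnerChainItem hook takes a next_callback argument, and Runner composes the registered items with functools.partial so that the first item added wraps all later ones (see demo/mnist/trainer.py below). A minimal, Paddle-independent sketch of that composition idea; the outer, inner and terminal names are illustrative only, not part of the commit:

import functools


def outer(context, next_callback):
    # Runs around everything that sits after it in the chain.
    print 'outer: before'
    next_callback(context)
    print 'outer: after'


def inner(context, next_callback):
    print 'inner sees', context


def terminal(*args, **kwargs):
    # Mirrors default_next_callback in trainer.py: ends the chain.
    return False


# Wrap back-to-front, so the first item ends up outermost,
# the same way Runner.initialize builds its callbacks.
callback = terminal
for item in reversed([outer, inner]):
    callback = functools.partial(item, next_callback=callback)

callback({'pass_id': 0})
# prints:
#   outer: before
#   inner sees {'pass_id': 0}
#   outer: after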

File tree

2 files changed: +247 -110 lines

demo/mnist/api_train.py

Lines changed: 138 additions & 110 deletions
@@ -14,6 +14,8 @@
 from mnist_util import read_from_mnist
 from paddle.trainer_config_helpers import *
 
+from trainer import *
+
 
 def optimizer_config():
     settings(
@@ -72,122 +74,132 @@ def input_order_converter(generator):
         yield each_item['pixel'], each_item['label']
 
 
-def main():
-    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
-
-    # get enable_types for each optimizer.
-    # enable_types = [value, gradient, momentum, etc]
-    # For each optimizer(SGD, Adam), GradientMachine should enable different
-    # buffers.
-    opt_config_proto = parse_optimizer_config(optimizer_config)
-    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
-    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
-    enable_types = _temp_optimizer_.getParameterTypes()
-
-    # Create Simple Gradient Machine.
-    model_config = parse_network_config(network_config)
-    m = api.GradientMachine.createFromConfigProto(
-        model_config, api.CREATE_MODE_NORMAL, enable_types)
-
-    # This type check is not useful. Only enable type hint in IDE.
-    # Such as PyCharm
-    assert isinstance(m, api.GradientMachine)
-
-    # Initialize Parameter by numpy.
-    init_parameter(network=m)
-
-    # Create Local Updater. Local means not run in cluster.
-    # For a cluster training, here we can change to createRemoteUpdater
-    # in future.
-    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
-    assert isinstance(updater, api.ParameterUpdater)
-
-    # Initialize ParameterUpdater.
-    updater.init(m)
-
-    # DataProvider Converter is a utility convert Python Object to Paddle C++
-    # Input. The input format is as same as Paddle's DataProvider.
-    converter = DataProviderConverter(
-        input_types=[dp.dense_vector(784), dp.integer_value(10)])
-
-    train_file = './data/raw_data/train'
-    test_file = './data/raw_data/t10k'
-
-    # start gradient machine.
-    # the gradient machine must be started before invoke forward/backward.
-    # not just for training, but also for inference.
-    m.start()
-
-    # evaluator can print error rate, etc. It is a C++ class.
-    batch_evaluator = m.makeEvaluator()
-    test_evaluator = m.makeEvaluator()
-
-    # Get Train Data.
-    # TrainData will stored in a data pool. Currently implementation is not care
-    # about memory, speed. Just a very naive implementation.
-    train_data_generator = input_order_converter(read_from_mnist(train_file))
-    train_data = BatchPool(train_data_generator, 512)
-
-    # outArgs is Neural Network forward result. Here is not useful, just passed
-    # to gradient_machine.forward
-    outArgs = api.Arguments.createArguments(0)
-
-    for pass_id in xrange(2):  # we train 2 passes.
-        updater.startPass()
-
-        for batch_id, data_batch in enumerate(train_data()):
-            # data_batch is input images.
-            # here, for online learning, we could get data_batch from network.
-
-            # Start update one batch.
-            pass_type = updater.startBatch(len(data_batch))
-
-            # Start BatchEvaluator.
-            # batch_evaluator can be used between start/finish.
-            batch_evaluator.start()
-
-            # forwardBackward is a shortcut for forward and backward.
-            # It is sometimes faster than invoke forward/backward separately,
-            # because in GradientMachine, it may be async.
-            m.forwardBackward(converter(data_batch), outArgs, pass_type)
-
-            for each_param in m.getParameters():
-                updater.update(each_param)
-
-            # Get cost. We use numpy to calculate total cost for this batch.
-            cost_vec = outArgs.getSlotValue(0)
-            cost_vec = cost_vec.copyToNumpyMat()
-            cost = cost_vec.sum() / len(data_batch)
-
-            # Make evaluator works.
-            m.eval(batch_evaluator)
-
-            # Print logs.
-            print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \
-                cost, batch_evaluator
-
-            batch_evaluator.finish()
-            # Finish batch.
-            # * will clear gradient.
-            # * ensure all values should be updated.
-            updater.finishBatch(cost)
-
+class MonolithicChainItem(RunnerChainItem):
+    def finalize(self, context, next_callback):
+        context.gradient_machine.finish()
+
+    def initialize(self, context, next_callback):
+        api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
+
+        # get enable_types for each optimizer.
+        # enable_types = [value, gradient, momentum, etc]
+        # For each optimizer(SGD, Adam), GradientMachine should enable different
+        # buffers.
+        opt_config_proto = parse_optimizer_config(optimizer_config)
+        opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
+        _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
+        enable_types = _temp_optimizer_.getParameterTypes()
+
+        # Create Simple Gradient Machine.
+        model_config = parse_network_config(network_config)
+        context.gradient_machine = api.GradientMachine.createFromConfigProto(
+            model_config, api.CREATE_MODE_NORMAL, enable_types)
+
+        # This type check is not useful. Only enable type hint in IDE.
+        # Such as PyCharm
+        assert isinstance(context.gradient_machine, api.GradientMachine)
+
+        # Initialize Parameter by numpy.
+        init_parameter(network=context.gradient_machine)
+
+        # Create Local Updater. Local means not run in cluster.
+        # For a cluster training, here we can change to createRemoteUpdater
+        # in future.
+        context.updater = api.ParameterUpdater.createLocalUpdater(opt_config)
+        assert isinstance(context.updater, api.ParameterUpdater)
+        context.updater.init(context.gradient_machine)
+
+        # DataProvider Converter is a utility convert Python Object to Paddle C++
+        # Input. The input format is as same as Paddle's DataProvider.
+        context.data_converter = DataProviderConverter(
+            input_types=[dp.dense_vector(784), dp.integer_value(10)])
+
+        train_file = './data/raw_data/train'
+        test_file = './data/raw_data/t10k'
+
+        context.gradient_machine.start()
+
+        # Get Train Data.
+        # TrainData will stored in a data pool. Currently implementation is not care
+        # about memory, speed. Just a very naive implementation.
+        train_data_generator = input_order_converter(
+            read_from_mnist(train_file))
+        train_data = BatchPool(train_data_generator, 512)
+        context.train_data_callback = train_data
+        context.test_file = test_file
+
+        next_callback(context)
+
+    def on_batch_begin(self, context, next_callback):
+        batch_evaluator = context.gradient_machine.makeEvaluator()
+        # outArgs is Neural Network forward result. Here is not useful, just passed
+        # to gradient_machine.forward
+        outArgs = api.Arguments.createArguments(0)
+
+        try:
+            data_batch = next(context.train_data)
+        except StopIteration:
+            return True
+
+        # data_batch is input images.
+        # here, for online learning, we could get data_batch from network.
+
+        # Start update one batch.
+        pass_type = context.updater.startBatch(len(data_batch))
+
+        # Start BatchEvaluator.
+        # batch_evaluator can be used between start/finish.
+        batch_evaluator.start()
+
+        # forwardBackward is a shortcut for forward and backward.
+        # It is sometimes faster than invoke forward/backward separately,
+        # because in GradientMachine, it may be async.
+        context.gradient_machine.forwardBackward(
+            context.data_converter(data_batch), outArgs, pass_type)
+
+        for each_param in context.gradient_machine.getParameters():
+            context.updater.update(each_param)
+
+        # Get cost. We use numpy to calculate total cost for this batch.
+        cost_vec = outArgs.getSlotValue(0)
+        cost_vec = cost_vec.copyToNumpyMat()
+        cost = cost_vec.sum() / len(data_batch)
+
+        # Make evaluator works.
+        context.gradient_machine.eval(batch_evaluator)
+
+        # Print logs.
+        print 'batch with cost=', cost, batch_evaluator
+
+        batch_evaluator.finish()
+        context.cost = cost
+        return False
+
+    def on_pass_begin(self, context, next_callback):
+        context.updater.startPass()
+        context.train_data = context.train_data_callback()
+
+    def on_pass_end(self, context, next_callback):
         # testing stage. use test data set to test current network.
-        updater.apply()
+        outArgs = api.Arguments.createArguments(0)
+        context.updater.apply()
+        test_evaluator = context.gradient_machine.makeEvaluator()
         test_evaluator.start()
-        test_data_generator = input_order_converter(read_from_mnist(test_file))
+        test_data_generator = input_order_converter(
+            read_from_mnist(context.test_file))
         for data_batch in generator_to_batch(test_data_generator, 512):
             # in testing stage, only forward is needed.
-            m.forward(converter(data_batch), outArgs, api.PASS_TEST)
-            m.eval(test_evaluator)
+            context.gradient_machine.forward(
+                context.data_converter(data_batch), outArgs, api.PASS_TEST)
+            context.gradient_machine.eval(test_evaluator)
 
         # print error rate for test data set
-        print 'Pass', pass_id, ' test evaluator: ', test_evaluator
+        print 'Test evaluator: ', test_evaluator
         test_evaluator.finish()
-        updater.restore()
+        context.updater.restore()
 
-        updater.catchUpWith()
-        params = m.getParameters()
+        context.updater.catchUpWith()
+        params = context.gradient_machine.getParameters()
         for each_param in params:
             assert isinstance(each_param, api.Parameter)
             value = each_param.getBuf(api.PARAMETER_VALUE)
@@ -196,9 +208,25 @@ def main():
             # Here, we could save parameter to every where you want
             print each_param.getName(), value
 
-        updater.finishPass()
+        context.updater.finishPass()
+
+    def on_batch_end(self, context, next_callback):
+        # Finish batch.
+        # * will clear gradient.
+        # * ensure all values should be updated.
+        context.updater.finishBatch(context.cost)
+        return False
 
-    m.finish()
+    def __init__(self):
+        RunnerChainItem.__init__(self)
+
+
+def main():
+    runner = Runner()
+    runner.add_chain_item(MonolithicChainItem())
+    with runner.use():
+        for _ in xrange(2):
+            runner.run_one_pass()
 
 
 if __name__ == '__main__':

demo/mnist/trainer.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+import functools
+
+__all__ = ['RunnerChainItem', 'Runner']
+
+
+class RunnerChainItem(object):
+    def __init__(self):
+        pass
+
+    def initialize(self, context, next_callback):
+        next_callback(context)
+
+    def finalize(self, context, next_callback):
+        next_callback(context)
+
+    def on_pass_begin(self, context, next_callback):
+        next_callback(context)
+
+    def on_pass_end(self, context, next_callback):
+        next_callback(context)
+
+    def on_batch_begin(self, context, next_callback):
+        return next_callback(context)
+
+    def on_batch_end(self, context, next_callback):
+        return next_callback(context)
+
+
+def default_next_callback(*args, **kwargs):
+    return False
+
+
+class RunnerContext(object):
+    pass
+
+
+class RunnerSection(object):
+    def __init__(self, runner):
+        self.runner = runner
+
+    def __enter__(self):
+        self.runner.initialize()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.runner.finalize()
+
+
+class Runner(object):
+    def __init__(self):
+        self.chains = []
+
+        self.begin_pass = None
+        self.end_pass = None
+        self.begin_batch = None
+        self.end_batch = None
+        self.finalize = None
+
+        self.context = RunnerContext()
+        self.context.runner = self
+
+    def add_chain_item(self, item):
+        assert isinstance(item, RunnerChainItem)
+        self.chains.append(item)
+
+    def initialize(self):
+        if None not in [
+                self.begin_pass, self.end_pass, self.begin_batch,
+                self.end_batch, self.finalize
+        ]:
+            return False
+        else:
+            assert len(self.chains) != 0
+            actual_init = default_next_callback
+            self.begin_pass = default_next_callback
+            self.end_pass = default_next_callback
+            self.begin_batch = default_next_callback
+            self.end_batch = default_next_callback
+            self.finalize = default_next_callback
+
+            for chain in reversed(self.chains):
+                assert isinstance(chain, RunnerChainItem)
+                actual_init = functools.partial(
+                    chain.initialize, next_callback=actual_init)
+                self.begin_pass = functools.partial(
+                    chain.on_pass_begin, next_callback=self.begin_pass)
+                self.end_pass = functools.partial(
+                    chain.on_pass_end, next_callback=self.end_pass)
+                self.begin_batch = functools.partial(
+                    chain.on_batch_begin, next_callback=self.begin_batch)
+                self.end_batch = functools.partial(
+                    chain.on_batch_end, next_callback=self.end_batch)
+                self.finalize = functools.partial(
+                    chain.finalize, next_callback=self.finalize)
+
+            actual_init(self.context)
+            return True
+
+    def run_one_pass(self):
+        self.begin_pass(self.context)
+        exit_flag = False
+        while not exit_flag:
+            exit_flag = self.begin_batch(self.context)
+            if exit_flag:
+                break
+            exit_flag = self.end_batch(self.context)
+        self.end_pass(self.context)
+
+    def use(self):
+        return RunnerSection(self)

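A small usage sketch of the chain API above, assuming demo/mnist/trainer.py is importable as trainer; LoggingItem and CountingItem are illustrative names, not part of this commit. CountingItem drives the batch loop and stops it by returning True from on_batch_begin after three batches, while LoggingItem layers pass-level logging on top without touching it:

from trainer import Runner, RunnerChainItem


class LoggingItem(RunnerChainItem):
    def on_pass_begin(self, context, next_callback):
        print 'pass begins'
        next_callback(context)

    def on_pass_end(self, context, next_callback):
        next_callback(context)
        print 'pass ends'


class CountingItem(RunnerChainItem):
    def on_pass_begin(self, context, next_callback):
        context.batch_id = 0
        next_callback(context)

    def on_batch_begin(self, context, next_callback):
        # Returning True tells Runner.run_one_pass to leave the batch loop.
        if context.batch_id == 3:
            return True
        return next_callback(context)

    def on_batch_end(self, context, next_callback):
        print 'finished batch', context.batch_id
        context.batch_id += 1
        return next_callback(context)


runner = Runner()
# Items added first wrap the items added later via next_callback.
runner.add_chain_item(LoggingItem())
runner.add_chain_item(CountingItem())
runner.initialize()
runner.run_one_pass()
# prints: pass begins, finished batch 0 / 1 / 2, pass ends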