 from mnist_util import read_from_mnist
 from paddle.trainer_config_helpers import *

+from trainer import *
+

 def optimizer_config():
     settings(
@@ -72,122 +74,132 @@ def input_order_converter(generator):
         yield each_item['pixel'], each_item['label']


-def main():
-    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
-
-    # get enable_types for each optimizer.
-    # enable_types = [value, gradient, momentum, etc]
-    # For each optimizer(SGD, Adam), GradientMachine should enable different
-    # buffers.
-    opt_config_proto = parse_optimizer_config(optimizer_config)
-    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
-    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
-    enable_types = _temp_optimizer_.getParameterTypes()
-
-    # Create Simple Gradient Machine.
-    model_config = parse_network_config(network_config)
-    m = api.GradientMachine.createFromConfigProto(
-        model_config, api.CREATE_MODE_NORMAL, enable_types)
-
-    # This type check is not useful. Only enable type hint in IDE.
-    # Such as PyCharm
-    assert isinstance(m, api.GradientMachine)
-
-    # Initialize Parameter by numpy.
-    init_parameter(network=m)
-
-    # Create Local Updater. Local means not run in cluster.
-    # For a cluster training, here we can change to createRemoteUpdater
-    # in future.
-    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
-    assert isinstance(updater, api.ParameterUpdater)
-
-    # Initialize ParameterUpdater.
-    updater.init(m)
-
-    # DataProvider Converter is a utility convert Python Object to Paddle C++
-    # Input. The input format is as same as Paddle's DataProvider.
-    converter = DataProviderConverter(
-        input_types=[dp.dense_vector(784), dp.integer_value(10)])
-
-    train_file = './data/raw_data/train'
-    test_file = './data/raw_data/t10k'
-
-    # start gradient machine.
-    # the gradient machine must be started before invoke forward/backward.
-    # not just for training, but also for inference.
-    m.start()
-
-    # evaluator can print error rate, etc. It is a C++ class.
-    batch_evaluator = m.makeEvaluator()
-    test_evaluator = m.makeEvaluator()
-
-    # Get Train Data.
-    # TrainData will stored in a data pool. Currently implementation is not care
-    # about memory, speed. Just a very naive implementation.
-    train_data_generator = input_order_converter(read_from_mnist(train_file))
-    train_data = BatchPool(train_data_generator, 512)
-
-    # outArgs is Neural Network forward result. Here is not useful, just passed
-    # to gradient_machine.forward
-    outArgs = api.Arguments.createArguments(0)
-
-    for pass_id in xrange(2):  # we train 2 passes.
-        updater.startPass()
-
-        for batch_id, data_batch in enumerate(train_data()):
-            # data_batch is input images.
-            # here, for online learning, we could get data_batch from network.
-
-            # Start update one batch.
-            pass_type = updater.startBatch(len(data_batch))
-
-            # Start BatchEvaluator.
-            # batch_evaluator can be used between start/finish.
-            batch_evaluator.start()
-
-            # forwardBackward is a shortcut for forward and backward.
-            # It is sometimes faster than invoke forward/backward separately,
-            # because in GradientMachine, it may be async.
-            m.forwardBackward(converter(data_batch), outArgs, pass_type)
-
-            for each_param in m.getParameters():
-                updater.update(each_param)
-
-            # Get cost. We use numpy to calculate total cost for this batch.
-            cost_vec = outArgs.getSlotValue(0)
-            cost_vec = cost_vec.copyToNumpyMat()
-            cost = cost_vec.sum() / len(data_batch)
-
-            # Make evaluator works.
-            m.eval(batch_evaluator)
-
-            # Print logs.
-            print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \
-                cost, batch_evaluator
-
-            batch_evaluator.finish()
-            # Finish batch.
-            # * will clear gradient.
-            # * ensure all values should be updated.
-            updater.finishBatch(cost)
-
+class MonolithicChainItem(RunnerChainItem):
+    def finalize(self, context, next_callback):
+        context.gradient_machine.finish()
+
+    def initialize(self, context, next_callback):
+        api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
+
+        # get enable_types for each optimizer.
+        # enable_types = [value, gradient, momentum, etc]
+        # For each optimizer(SGD, Adam), GradientMachine should enable different
+        # buffers.
+        opt_config_proto = parse_optimizer_config(optimizer_config)
+        opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
+        _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
+        enable_types = _temp_optimizer_.getParameterTypes()
+
+        # Create Simple Gradient Machine.
+        model_config = parse_network_config(network_config)
+        context.gradient_machine = api.GradientMachine.createFromConfigProto(
+            model_config, api.CREATE_MODE_NORMAL, enable_types)
+
+        # This type check is not useful. Only enable type hint in IDE.
+        # Such as PyCharm
+        assert isinstance(context.gradient_machine, api.GradientMachine)
+
+        # Initialize Parameter by numpy.
+        init_parameter(network=context.gradient_machine)
+
+        # Create Local Updater. Local means not run in cluster.
+        # For a cluster training, here we can change to createRemoteUpdater
+        # in future.
+        context.updater = api.ParameterUpdater.createLocalUpdater(opt_config)
+        assert isinstance(context.updater, api.ParameterUpdater)
+        context.updater.init(context.gradient_machine)
+
+        # DataProvider Converter is a utility convert Python Object to Paddle C++
+        # Input. The input format is as same as Paddle's DataProvider.
+        context.data_converter = DataProviderConverter(
+            input_types=[dp.dense_vector(784), dp.integer_value(10)])
+
+        train_file = './data/raw_data/train'
+        test_file = './data/raw_data/t10k'
+
+        context.gradient_machine.start()
+
+        # Get Train Data.
+        # TrainData will stored in a data pool. Currently implementation is not care
+        # about memory, speed. Just a very naive implementation.
+        train_data_generator = input_order_converter(
+            read_from_mnist(train_file))
+        train_data = BatchPool(train_data_generator, 512)
+        context.train_data_callback = train_data
+        context.test_file = test_file
+
+        next_callback(context)
+
+    def on_batch_begin(self, context, next_callback):
+        batch_evaluator = context.gradient_machine.makeEvaluator()
+        # outArgs is Neural Network forward result. Here is not useful, just passed
+        # to gradient_machine.forward
+        outArgs = api.Arguments.createArguments(0)
+
+        try:
+            data_batch = next(context.train_data)
+        except StopIteration:
+            return True
+
+        # data_batch is input images.
+        # here, for online learning, we could get data_batch from network.
+
+        # Start update one batch.
+        pass_type = context.updater.startBatch(len(data_batch))
+
+        # Start BatchEvaluator.
+        # batch_evaluator can be used between start/finish.
+        batch_evaluator.start()
+
+        # forwardBackward is a shortcut for forward and backward.
+        # It is sometimes faster than invoke forward/backward separately,
+        # because in GradientMachine, it may be async.
+        context.gradient_machine.forwardBackward(
+            context.data_converter(data_batch), outArgs, pass_type)
+
+        for each_param in context.gradient_machine.getParameters():
+            context.updater.update(each_param)
+
+        # Get cost. We use numpy to calculate total cost for this batch.
+        cost_vec = outArgs.getSlotValue(0)
+        cost_vec = cost_vec.copyToNumpyMat()
+        cost = cost_vec.sum() / len(data_batch)
+
+        # Make evaluator works.
+        context.gradient_machine.eval(batch_evaluator)
+
+        # Print logs.
+        print 'batch with cost=', cost, batch_evaluator
+
+        batch_evaluator.finish()
+        context.cost = cost
+        return False
+
+    def on_pass_begin(self, context, next_callback):
+        context.updater.startPass()
+        context.train_data = context.train_data_callback()
+
+    def on_pass_end(self, context, next_callback):
         # testing stage. use test data set to test current network.
-        updater.apply()
+        outArgs = api.Arguments.createArguments(0)
+        context.updater.apply()
+        test_evaluator = context.gradient_machine.makeEvaluator()
         test_evaluator.start()
-        test_data_generator = input_order_converter(read_from_mnist(test_file))
+        test_data_generator = input_order_converter(
+            read_from_mnist(context.test_file))
         for data_batch in generator_to_batch(test_data_generator, 512):
             # in testing stage, only forward is needed.
-            m.forward(converter(data_batch), outArgs, api.PASS_TEST)
-            m.eval(test_evaluator)
+            context.gradient_machine.forward(
+                context.data_converter(data_batch), outArgs, api.PASS_TEST)
+            context.gradient_machine.eval(test_evaluator)

         # print error rate for test data set
-        print 'Pass', pass_id, ' test evaluator: ', test_evaluator
+        print 'Test evaluator: ', test_evaluator
         test_evaluator.finish()
-        updater.restore()
+        context.updater.restore()

-        updater.catchUpWith()
-        params = m.getParameters()
+        context.updater.catchUpWith()
+        params = context.gradient_machine.getParameters()
         for each_param in params:
             assert isinstance(each_param, api.Parameter)
             value = each_param.getBuf(api.PARAMETER_VALUE)
@@ -196,9 +208,25 @@ def main():
             # Here, we could save parameter to every where you want
             print each_param.getName(), value

-        updater.finishPass()
+        context.updater.finishPass()
+
+    def on_batch_end(self, context, next_callback):
+        # Finish batch.
+        # * will clear gradient.
+        # * ensure all values should be updated.
+        context.updater.finishBatch(context.cost)
+        return False

-    m.finish()
+    def __init__(self):
+        RunnerChainItem.__init__(self)
+
+
+def main():
+    runner = Runner()
+    runner.add_chain_item(MonolithicChainItem())
+    with runner.use():
+        for _ in xrange(2):
+            runner.run_one_pass()


 if __name__ == '__main__':
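
Note: `Runner`, `RunnerChainItem`, and the shared context object come from the new `trainer` module that this commit imports but does not show. The code below is only an illustrative sketch of what that interface could look like, inferred from how `MonolithicChainItem` and `main()` use it in the diff above; the names `RunnerContext`, `_chain`, and `make_next` are invented for the sketch, and the real module may be structured differently.

# Hypothetical sketch only -- inferred from the usage in this diff, not the
# actual `trainer` module.
import contextlib


class RunnerChainItem(object):
    # Default no-op hooks; concrete items override the ones they need.
    def initialize(self, context, next_callback):
        next_callback(context)

    def finalize(self, context, next_callback):
        next_callback(context)

    def on_pass_begin(self, context, next_callback):
        next_callback(context)

    def on_pass_end(self, context, next_callback):
        next_callback(context)

    def on_batch_begin(self, context, next_callback):
        # Returning True means "this pass has no more batches".
        return next_callback(context)

    def on_batch_end(self, context, next_callback):
        return next_callback(context)


class RunnerContext(object):
    # Plain attribute bag shared by every chain item (gradient_machine,
    # updater, data_converter, train_data, cost, ... in the diff above).
    pass


class Runner(object):
    def __init__(self):
        self.items = []
        self.context = RunnerContext()

    def add_chain_item(self, item):
        self.items.append(item)

    @contextlib.contextmanager
    def use(self):
        # Initialize every chain item on entry, finalize them on exit.
        self._chain('initialize')
        try:
            yield self
        finally:
            self._chain('finalize')

    def run_one_pass(self):
        self._chain('on_pass_begin')
        while not self._chain('on_batch_begin', default=True):
            self._chain('on_batch_end', default=False)
        self._chain('on_pass_end')

    def _chain(self, name, default=None):
        # Each item receives a callback that invokes the same hook on the
        # next item, so items can be stacked (training, logging, ...).
        def make_next(idx):
            def next_callback(context):
                if idx < len(self.items):
                    hook = getattr(self.items[idx], name)
                    return hook(context, make_next(idx + 1))
                return default
            return next_callback
        return make_next(0)(self.context)

Under this reading, `runner.run_one_pass()` keeps calling `on_batch_begin` until the chain item reports that its data generator is exhausted, which is why `on_batch_begin` in the diff returns True on `StopIteration` and False otherwise.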