@@ -53,7 +53,7 @@ def input_order_converter(generator):
53
53
54
54
55
55
def main ():
56
- api .initPaddle ("-use_gpu=true " , "-trainer_count=4" ) # use 4 cpu cores
56
+ api .initPaddle ("-use_gpu=false " , "-trainer_count=4" ) # use 4 cpu cores
57
57
config = paddle .trainer .config_parser .parse_config (
58
58
'simple_mnist_network.py' , '' )
59
59
@@ -106,7 +106,7 @@ def main():
106
106
# TrainData will stored in a data pool. Currently implementation is not care
107
107
# about memory, speed. Just a very naive implementation.
108
108
train_data_generator = input_order_converter (read_from_mnist (train_file ))
109
- train_data = BatchPool (train_data_generator , 128 )
109
+ train_data = BatchPool (train_data_generator , 512 )
110
110
111
111
# outArgs is Neural Network forward result. Here is not useful, just passed
112
112
# to gradient_machine.forward
@@ -126,16 +126,13 @@ def main():
126
126
# batch_evaluator can be used between start/finish.
127
127
batch_evaluator .start ()
128
128
129
- # A callback when backward.
130
- # It is used for updating weight values vy calculated Gradient.
131
- def updater_callback (param ):
132
- updater .update (param )
133
-
134
129
# forwardBackward is a shortcut for forward and backward.
135
130
# It is sometimes faster than invoke forward/backward separately,
136
131
# because in GradientMachine, it may be async.
137
- m .forwardBackward (
138
- converter (data_batch ), outArgs , pass_type , updater_callback )
132
+ m .forwardBackward (converter (data_batch ), outArgs , pass_type )
133
+
134
+ for each_param in m .getParameters ():
135
+ updater .update (each_param )
139
136
140
137
# Get cost. We use numpy to calculate total cost for this batch.
141
138
cost_vec = outArgs .getSlotValue (0 )
@@ -159,7 +156,7 @@ def updater_callback(param):
159
156
updater .apply ()
160
157
test_evaluator .start ()
161
158
test_data_generator = input_order_converter (read_from_mnist (test_file ))
162
- for data_batch in generator_to_batch (test_data_generator , 128 ):
159
+ for data_batch in generator_to_batch (test_data_generator , 512 ):
163
160
# in testing stage, only forward is needed.
164
161
m .forward (converter (data_batch ), outArgs , api .PASS_TEST )
165
162
m .eval (test_evaluator )
0 commit comments