6
6
7
7
The user API could be simpler and more carefully designed.
8
8
"""
9
- import py_paddle . swig_paddle as api
10
- from py_paddle import DataProviderConverter
9
+
10
+ import mnist_provider
11
11
import paddle .trainer .PyDataProvider2 as dp
12
- import numpy as np
13
- import random
12
+ import py_paddle .swig_paddle as api
14
13
from mnist_util import read_from_mnist
15
14
from paddle .trainer_config_helpers import *
16
-
17
15
from trainer import *
18
16
19
17
20
- def optimizer_config ():
21
- settings (
22
- learning_rate = 1e-4 ,
23
- learning_method = AdamOptimizer ( ),
24
- batch_size = 1000 ,
25
- model_average = ModelAverage ( average_window = 0.5 ) ,
26
- regularization = L2Regularization ( rate = 0.5 ))
27
-
28
-
29
- def network_config ():
30
- imgs = data_layer ( name = ' pixel' , size = 784 )
31
- hidden1 = fc_layer (input = imgs , size = 200 )
18
@network(
    inputs={
        'pixel': dp.dense_vector(784),
        'label': dp.integer_value(10),
    },
    learning_rate=1e-4,
    learning_method=AdamOptimizer(),
    batch_size=1000,
    model_average=ModelAverage(average_window=0.5),
    regularization=L2Regularization(rate=0.5))
def mnist_network(pixel, label):
    """Build the MNIST classifier topology and return its cost layer.

    The network is two fully connected layers of size 200 followed by a
    size-10 fully connected layer with softmax activation; the result is
    the classification cost of that output against ``label``.

    The input declarations and optimizer settings are passed to the
    ``network`` decorator (imported from ``trainer``).
    NOTE(review): presumably the decorator supplies ``pixel`` and ``label``
    as layers built from ``inputs`` -- confirm against ``trainer.network``.
    """
    first_hidden = fc_layer(input=pixel, size=200)
    second_hidden = fc_layer(input=first_hidden, size=200)
    prediction = fc_layer(
        input=second_hidden, size=10, act=SoftmaxActivation())
    return classification_cost(input=prediction, label=label)
47
34
48
35
49
36
def generator_to_batch (generator , batch_size ):
@@ -57,18 +44,6 @@ def generator_to_batch(generator, batch_size):
57
44
yield ret_val
58
45
59
46
60
- class BatchPool (object ):
61
- def __init__ (self , generator , batch_size ):
62
- self .data = list (generator )
63
- self .batch_size = batch_size
64
-
65
- def __call__ (self ):
66
- random .shuffle (self .data )
67
- for offset in xrange (0 , len (self .data ), self .batch_size ):
68
- limit = min (offset + self .batch_size , len (self .data ))
69
- yield self .data [offset :limit ]
70
-
71
-
72
47
def input_order_converter(generator):
    """Yield ``(pixel, label)`` tuples from an iterable of keyed samples.

    Each item produced by ``generator`` must support lookup by the keys
    ``'pixel'`` and ``'label'``; the two values are emitted in that order.
    """
    for sample in generator:
        pixel = sample['pixel']
        label = sample['label']
        yield pixel, label
@@ -79,53 +54,8 @@ def finalize(self, context, next_callback):
79
54
context .gradient_machine .finish ()
80
55
81
56
def initialize(self, context, next_callback):
    """Record the test-set path on ``context`` and continue the chain.

    Stores ``'./data/raw_data/t10k'`` as ``context.test_file`` and then
    invokes ``next_callback`` with the same ``context``.
    """
    context.test_file = './data/raw_data/t10k'

    next_callback(context)
@@ -136,34 +66,24 @@ def on_batch_begin(self, context, next_callback):
136
66
# to gradient_machine.forward
137
67
outArgs = api .Arguments .createArguments (0 )
138
68
139
- try :
140
- data_batch = next (context .train_data )
141
- except StopIteration :
142
- return True
143
-
144
- # data_batch is input images.
145
- # here, for online learning, we could get data_batch from network.
146
-
147
- # Start update one batch.
148
- pass_type = context .updater .startBatch (len (data_batch ))
149
-
150
69
# Start BatchEvaluator.
151
70
# batch_evaluator can be used between start/finish.
152
71
batch_evaluator .start ()
153
72
154
73
# forwardBackward is a shortcut for forward and backward.
155
74
# It is sometimes faster than invoke forward/backward separately,
156
75
# because in GradientMachine, it may be async.
157
- context .gradient_machine .forwardBackward (
158
- context . data_converter ( data_batch ), outArgs , pass_type )
76
+ context .gradient_machine .forwardBackward (context . in_args , outArgs ,
77
+ api . PASS_TRAIN )
159
78
160
79
for each_param in context .gradient_machine .getParameters ():
161
80
context .updater .update (each_param )
162
81
163
82
# Get cost. We use numpy to calculate total cost for this batch.
164
83
cost_vec = outArgs .getSlotValue (0 )
165
84
cost_vec = cost_vec .copyToNumpyMat ()
166
- cost = cost_vec .sum () / len (data_batch )
85
+ cost = cost_vec .sum () / context .current_batch_size
86
+ context .current_cost = cost
167
87
168
88
# Make evaluator works.
169
89
context .gradient_machine .eval (batch_evaluator )
@@ -175,10 +95,6 @@ def on_batch_begin(self, context, next_callback):
175
95
context .cost = cost
176
96
return False
177
97
178
- def on_pass_begin (self , context , next_callback ):
179
- context .updater .startPass ()
180
- context .train_data = context .train_data_callback ()
181
-
182
98
def on_pass_end (self , context , next_callback ):
183
99
# testing stage. use test data set to test current network.
184
100
outArgs = api .Arguments .createArguments (0 )
@@ -208,21 +124,26 @@ def on_pass_end(self, context, next_callback):
208
124
# Here, we could save parameter to every where you want
209
125
print each_param .getName (), value
210
126
211
- context .updater .finishPass ()
212
-
213
- def on_batch_end (self , context , next_callback ):
214
- # Finish batch.
215
- # * will clear gradient.
216
- # * ensure all values should be updated.
217
- context .updater .finishBatch (context .cost )
218
- return False
219
-
220
127
def __init__(self):
    """Initialise the chain item by delegating to ``RunnerChainItem``."""
    # Explicit base-class call; no extra state is set up here.
    RunnerChainItem.__init__(self)
222
129
223
130
224
131
def main ():
132
+ mnist = mnist_network ()
133
+
225
134
runner = Runner ()
135
+ runner .add_chain_item (DeviceChainItem (use_gpu = False , device_count = 4 ))
136
+
137
+ runner .add_chain_item (CreateGradientMachine (network = mnist ))
138
+ runner .add_chain_item (RandomInitializeParams ())
139
+ runner .add_chain_item (
140
+ BasicTrainerDataProvider (
141
+ network = mnist ,
142
+ method = mnist_provider .process ,
143
+ file_list = ['./data/raw_data/train' ],
144
+ batch_size = 256 ))
145
+ runner .add_chain_item (BasicLocalParameterUpdater (network = mnist ))
146
+
226
147
runner .add_chain_item (MonolithicChainItem ())
227
148
with runner .use ():
228
149
for _ in xrange (2 ):
0 commit comments