"""
A very basic example of how to use the current raw SWIG API to train an MNIST
network.

The current implementation uses raw SWIG, which means the API calls are passed
directly to the C++ side of Paddle.

The user API could be simpler and more carefully designed.
"""
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
import numpy as np
import random
from mnist_util import read_from_mnist
from paddle.trainer_config_helpers import *


def optimizer_config():
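    """Declare the optimizer settings (trainer_config_helpers style): Adam
    with a learning rate of 1e-4, batch size 1000, model averaging and L2
    regularization."""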
    settings(
        learning_rate=1e-4,
        learning_method=AdamOptimizer(),
        batch_size=1000,
        model_average=ModelAverage(average_window=0.5),
        regularization=L2Regularization(rate=0.5))


def network_config():
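    """Define a simple multi-layer perceptron: a 784-dimensional image input,
    two fully-connected hidden layers of 200 units each, a 10-way softmax
    output, and a classification cost against the integer label."""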
    imgs = data_layer(name='pixel', size=784)
    hidden1 = fc_layer(input=imgs, size=200)
    hidden2 = fc_layer(input=hidden1, size=200)
    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
    cost = classification_cost(
        input=inference, label=data_layer(
            name='label', size=10))
    outputs(cost)


def init_parameter(network):
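    """Initialize every parameter of the GradientMachine with values drawn
    uniformly from [-1.0, 1.0]."""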
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array_size = len(each_param)
        array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
        each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)


def generator_to_batch(generator, batch_size):
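    """Group items from a generator into lists of at most batch_size items.
    The final batch may be smaller than batch_size."""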
    ret_val = list()
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = list()
    if len(ret_val) != 0:
        yield ret_val


class BatchPool(object):
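    """Load all samples from a generator into memory; each call yields them
    again as shuffled batches of batch_size samples."""
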
    def __init__(self, generator, batch_size):
        self.data = list(generator)
        self.batch_size = batch_size

    def __call__(self):
        random.shuffle(self.data)
        for offset in xrange(0, len(self.data), self.batch_size):
            limit = min(offset + self.batch_size, len(self.data))
            yield self.data[offset:limit]


def input_order_converter(generator):
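    """Convert each sample dict into the (pixel, label) tuple order expected
    by the DataProviderConverter created in main()."""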
    for each_item in generator:
        yield each_item['pixel'], each_item['label']


def main():
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 CPU cores

    # Get enable_types for the optimizer.
    # enable_types = [value, gradient, momentum, etc.]
    # Each optimizer (SGD, Adam, ...) requires the GradientMachine to enable
    # a different set of parameter buffers.
    opt_config_proto = parse_optimizer_config(optimizer_config)
    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    # Create a simple GradientMachine.
    model_config = parse_network_config(network_config)
    m = api.GradientMachine.createFromConfigProto(
        model_config, api.CREATE_MODE_NORMAL, enable_types)

    # This type check is not required at runtime; it only enables type hints
    # in IDEs such as PyCharm.
    assert isinstance(m, api.GradientMachine)

    # Initialize parameters with numpy.
    init_parameter(network=m)

    # Create a local updater. "Local" means it does not run in a cluster.
    # For cluster training, this could be changed to createRemoteUpdater in
    # the future.
    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)

    # Initialize the ParameterUpdater.
    updater.init(m)

    # DataProviderConverter is a utility that converts Python objects into
    # Paddle's C++ input. The input format is the same as Paddle's
    # DataProvider.
    converter = DataProviderConverter(
        input_types=[dp.dense_vector(784), dp.integer_value(10)])
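    # Calling converter(data_batch) below converts a batch of (pixel, label)
    # samples into the C++ input arguments consumed by forward/forwardBackward.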

    train_file = './data/raw_data/train'
    test_file = './data/raw_data/t10k'

    # Start the gradient machine.
    # The gradient machine must be started before invoking forward/backward,
    # not just for training but also for inference.
    m.start()

    # An evaluator can print the error rate, etc. It is a C++ class.
    batch_evaluator = m.makeEvaluator()
    test_evaluator = m.makeEvaluator()

    # Get the training data.
    # The training data is stored in a data pool. The current implementation
    # does not care about memory or speed; it is just a very naive
    # implementation.
    train_data_generator = input_order_converter(read_from_mnist(train_file))
    train_data = BatchPool(train_data_generator, 512)

    # outArgs holds the neural network's forward results. They are not used
    # here; outArgs is just passed along to gradient_machine.forward.
    outArgs = api.Arguments.createArguments(0)

    for pass_id in xrange(2):  # we train 2 passes.
        updater.startPass()

        for batch_id, data_batch in enumerate(train_data()):
            # data_batch is a list of input images with their labels.
            # For online learning, data_batch could instead be read from the
            # network here.

            # Start updating one batch.
            pass_type = updater.startBatch(len(data_batch))

            # Start the BatchEvaluator.
            # batch_evaluator can be used between start/finish.
            batch_evaluator.start()

            # forwardBackward is a shortcut for forward and backward.
            # It is sometimes faster than invoking forward/backward
            # separately, because inside the GradientMachine they may run
            # asynchronously.
            m.forwardBackward(converter(data_batch), outArgs, pass_type)

            for each_param in m.getParameters():
                updater.update(each_param)

            # Get the cost. We use numpy to compute the average cost over
            # this batch.
            cost_vec = outArgs.getSlotValue(0)
            cost_vec = cost_vec.copyToNumpyMat()
            cost = cost_vec.sum() / len(data_batch)

            # Let the evaluator accumulate this batch's results.
            m.eval(batch_evaluator)

            # Print logs.
            print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \
                cost, batch_evaluator

            batch_evaluator.finish()
            # Finish the batch.
            #  * clears the gradients.
            #  * ensures all parameter values are updated.
            updater.finishBatch(cost)

        # Testing stage: use the test data set to evaluate the current
        # network.
        updater.apply()
        test_evaluator.start()
        test_data_generator = input_order_converter(read_from_mnist(test_file))
        for data_batch in generator_to_batch(test_data_generator, 512):
            # In the testing stage, only the forward pass is needed.
            m.forward(converter(data_batch), outArgs, api.PASS_TEST)
            m.eval(test_evaluator)

        # Print the error rate for the test data set.
        print 'Pass', pass_id, ' test evaluator: ', test_evaluator
        test_evaluator.finish()
        updater.restore()

        updater.catchUpWith()
        params = m.getParameters()
        for each_param in params:
            assert isinstance(each_param, api.Parameter)
            value = each_param.getBuf(api.PARAMETER_VALUE)
            value = value.copyToNumpyArray()

            # Here we could save the parameters to wherever we want.
            print each_param.getName(), value

        updater.finishPass()

    m.finish()


if __name__ == '__main__':
    main()