Commit 446fccf

Add a network decorator for network definitions.
* Extract NewGradientMachine, ParamUpdater, DataProvider.
Parent: aba84aa
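The `@network` decorator itself is defined in the commit's second changed file, which this page does not render. As orientation before the diff below, here is a minimal sketch of how a decorator with this call signature could be wired up, under assumptions: the `settings`, `data_layer`, and `outputs` calls are the existing paddle.trainer_config_helpers v1 API, everything else (the factory structure, the `input_types` attribute, `.dim` usage) is guessed from the call site in the diff and is not the commit's actual code.

# Illustrative sketch only; the real @network lives in the commit's
# second file, not shown on this page.
import paddle.trainer.PyDataProvider2 as dp
from paddle.trainer_config_helpers import *


def network(inputs, **optimizer_settings):
    def decorator(topology_func):
        def network_factory():
            def config():
                # Optimizer half: what optimizer_config() used to do.
                settings(**optimizer_settings)
                # Topology half: one data_layer per declared input,
                # passed into the user's topology function. v1 InputType
                # exposes its size as .dim (e.g. dense_vector(784)).
                layers = {}
                for name, input_type in inputs.items():
                    layers[name] = data_layer(name=name, size=input_type.dim)
                outputs(topology_func(**layers))

            # Keep the declared input types around so chain items such as
            # BasicTrainerDataProvider can build a data converter from them.
            config.input_types = inputs
            return config

        return network_factory

    return decorator

With wiring like this, `mnist_network()` in the new `main()` would return a config callable that `parse_network_config` can consume, which matches how the diff passes `network=mnist` into the new chain items.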

File tree: 2 files changed, +336 −116 lines

demo/mnist/api_train.py

Lines changed: 35 additions, 114 deletions
@@ -6,44 +6,31 @@
 
 The user api could be simpler and carefully designed.
 """
-import py_paddle.swig_paddle as api
-from py_paddle import DataProviderConverter
+
+import mnist_provider
 import paddle.trainer.PyDataProvider2 as dp
-import numpy as np
-import random
+import py_paddle.swig_paddle as api
 from mnist_util import read_from_mnist
 from paddle.trainer_config_helpers import *
-
 from trainer import *
 
 
-def optimizer_config():
-    settings(
-        learning_rate=1e-4,
-        learning_method=AdamOptimizer(),
-        batch_size=1000,
-        model_average=ModelAverage(average_window=0.5),
-        regularization=L2Regularization(rate=0.5))
-
-
-def network_config():
-    imgs = data_layer(name='pixel', size=784)
-    hidden1 = fc_layer(input=imgs, size=200)
+@network(
+    inputs={
+        'pixel': dp.dense_vector(784),
+        'label': dp.integer_value(10),
+    },
+    learning_rate=1e-4,
+    learning_method=AdamOptimizer(),
+    batch_size=1000,
+    model_average=ModelAverage(average_window=0.5),
+    regularization=L2Regularization(rate=0.5))
+def mnist_network(pixel, label):
+    hidden1 = fc_layer(input=pixel, size=200)
     hidden2 = fc_layer(input=hidden1, size=200)
     inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
-    cost = classification_cost(
-        input=inference, label=data_layer(
-            name='label', size=10))
-    outputs(cost)
-
-
-def init_parameter(network):
-    assert isinstance(network, api.GradientMachine)
-    for each_param in network.getParameters():
-        assert isinstance(each_param, api.Parameter)
-        array_size = len(each_param)
-        array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
-        each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)
+    cost = classification_cost(input=inference, label=label)
+    return cost
 
 
 def generator_to_batch(generator, batch_size):
@@ -57,18 +44,6 @@ def generator_to_batch(generator, batch_size):
         yield ret_val
 
 
-class BatchPool(object):
-    def __init__(self, generator, batch_size):
-        self.data = list(generator)
-        self.batch_size = batch_size
-
-    def __call__(self):
-        random.shuffle(self.data)
-        for offset in xrange(0, len(self.data), self.batch_size):
-            limit = min(offset + self.batch_size, len(self.data))
-            yield self.data[offset:limit]
-
-
 def input_order_converter(generator):
     for each_item in generator:
         yield each_item['pixel'], each_item['label']
@@ -79,53 +54,8 @@ def finalize(self, context, next_callback):
         context.gradient_machine.finish()
 
     def initialize(self, context, next_callback):
-        api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
-
-        # get enable_types for each optimizer.
-        # enable_types = [value, gradient, momentum, etc]
-        # For each optimizer(SGD, Adam), GradientMachine should enable different
-        # buffers.
-        opt_config_proto = parse_optimizer_config(optimizer_config)
-        opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
-        _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
-        enable_types = _temp_optimizer_.getParameterTypes()
-
-        # Create Simple Gradient Machine.
-        model_config = parse_network_config(network_config)
-        context.gradient_machine = api.GradientMachine.createFromConfigProto(
-            model_config, api.CREATE_MODE_NORMAL, enable_types)
-
-        # This type check is not useful. Only enable type hint in IDE.
-        # Such as PyCharm
-        assert isinstance(context.gradient_machine, api.GradientMachine)
-
-        # Initialize Parameter by numpy.
-        init_parameter(network=context.gradient_machine)
-
-        # Create Local Updater. Local means not run in cluster.
-        # For a cluster training, here we can change to createRemoteUpdater
-        # in future.
-        context.updater = api.ParameterUpdater.createLocalUpdater(opt_config)
-        assert isinstance(context.updater, api.ParameterUpdater)
-        context.updater.init(context.gradient_machine)
-
-        # DataProvider Converter is a utility convert Python Object to Paddle C++
-        # Input. The input format is as same as Paddle's DataProvider.
-        context.data_converter = DataProviderConverter(
-            input_types=[dp.dense_vector(784), dp.integer_value(10)])
-
-        train_file = './data/raw_data/train'
         test_file = './data/raw_data/t10k'
 
-        context.gradient_machine.start()
-
-        # Get Train Data.
-        # TrainData will stored in a data pool. Currently implementation is not care
-        # about memory, speed. Just a very naive implementation.
-        train_data_generator = input_order_converter(
-            read_from_mnist(train_file))
-        train_data = BatchPool(train_data_generator, 512)
-        context.train_data_callback = train_data
         context.test_file = test_file
 
         next_callback(context)
@@ -136,34 +66,24 @@ def on_batch_begin(self, context, next_callback):
         # to gradient_machine.forward
         outArgs = api.Arguments.createArguments(0)
 
-        try:
-            data_batch = next(context.train_data)
-        except StopIteration:
-            return True
-
-        # data_batch is input images.
-        # here, for online learning, we could get data_batch from network.
-
-        # Start update one batch.
-        pass_type = context.updater.startBatch(len(data_batch))
-
 
         # Start BatchEvaluator.
         # batch_evaluator can be used between start/finish.
         batch_evaluator.start()
 
         # forwardBackward is a shortcut for forward and backward.
         # It is sometimes faster than invoke forward/backward separately,
        # because in GradientMachine, it may be async.
-        context.gradient_machine.forwardBackward(
-            context.data_converter(data_batch), outArgs, pass_type)
+        context.gradient_machine.forwardBackward(context.in_args, outArgs,
+                                                 api.PASS_TRAIN)
 
         for each_param in context.gradient_machine.getParameters():
             context.updater.update(each_param)
 
         # Get cost. We use numpy to calculate total cost for this batch.
         cost_vec = outArgs.getSlotValue(0)
         cost_vec = cost_vec.copyToNumpyMat()
-        cost = cost_vec.sum() / len(data_batch)
+        cost = cost_vec.sum() / context.current_batch_size
+        context.current_cost = cost
 
         # Make evaluator works.
         context.gradient_machine.eval(batch_evaluator)
@@ -175,10 +95,6 @@ def on_batch_begin(self, context, next_callback):
         context.cost = cost
         return False
 
-    def on_pass_begin(self, context, next_callback):
-        context.updater.startPass()
-        context.train_data = context.train_data_callback()
-
     def on_pass_end(self, context, next_callback):
         # testing stage. use test data set to test current network.
         outArgs = api.Arguments.createArguments(0)
@@ -208,21 +124,26 @@ def on_pass_end(self, context, next_callback):
             # Here, we could save parameter to every where you want
            print each_param.getName(), value
 
-        context.updater.finishPass()
-
-    def on_batch_end(self, context, next_callback):
-        # Finish batch.
-        # * will clear gradient.
-        # * ensure all values should be updated.
-        context.updater.finishBatch(context.cost)
-        return False
-
     def __init__(self):
         RunnerChainItem.__init__(self)
 
 
 def main():
+    mnist = mnist_network()
+
     runner = Runner()
+    runner.add_chain_item(DeviceChainItem(use_gpu=False, device_count=4))
+
+    runner.add_chain_item(CreateGradientMachine(network=mnist))
+    runner.add_chain_item(RandomInitializeParams())
+    runner.add_chain_item(
+        BasicTrainerDataProvider(
+            network=mnist,
+            method=mnist_provider.process,
+            file_list=['./data/raw_data/train'],
+            batch_size=256))
+    runner.add_chain_item(BasicLocalParameterUpdater(network=mnist))
+
     runner.add_chain_item(MonolithicChainItem())
     with runner.use():
         for _ in xrange(2):
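`Runner`, `RunnerChainItem`, and the concrete chain items (`DeviceChainItem`, `CreateGradientMachine`, `RandomInitializeParams`, `BasicTrainerDataProvider`, `BasicLocalParameterUpdater`, `MonolithicChainItem`) come from the commit's other changed file, which this page truncates. Judging only from the hook signatures visible above, where every hook takes `(context, next_callback)` and forwards via `next_callback(context)`, the chaining is plausibly middleware-style. A hypothetical reconstruction, not the commit's code:

# Hypothetical sketch inferred from the hook signatures in the diff
# above; names mirror the diff, bodies are guesses.
import contextlib


class Context(object):
    """Plain attribute bag shared by all chain items (assumed)."""
    pass


class RunnerChainItem(object):
    def initialize(self, context, next_callback):
        next_callback(context)

    def on_batch_begin(self, context, next_callback):
        # Returning True signals "no more batches in this pass".
        return next_callback(context)

    def on_pass_end(self, context, next_callback):
        next_callback(context)

    def finalize(self, context, next_callback):
        next_callback(context)


class Runner(object):
    def __init__(self):
        self.items = []

    def add_chain_item(self, item):
        self.items.append(item)

    def _chain(self, method_name):
        # Fold the items into one callable, middleware-style: each hook
        # receives the next item's same-named hook as its next_callback.
        callback = lambda context: False
        for item in reversed(self.items):
            hook = getattr(item, method_name)
            callback = (lambda h, nxt: lambda ctx: h(ctx, nxt))(hook, callback)
        return callback

    @contextlib.contextmanager
    def use(self):
        context = Context()
        self._chain('initialize')(context)
        try:
            yield context
        finally:
            self._chain('finalize')(context)

Under this reading, the new `main()` composes behavior purely by ordering items: device setup first, then gradient-machine creation, parameter initialization, data provision, and parameter updating, with `MonolithicChainItem` supplying the innermost per-batch training logic.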
