Skip to content

Commit 20a9caa

Browse files
committed
Remove MonoChainItem
1 parent 22f4ced commit 20a9caa

File tree

4 files changed

+43
-73
lines changed

4 files changed

+43
-73
lines changed

demo/mnist/api_train.py

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
The user api could be simpler and carefully designed.
88
"""
99

10-
import mnist_provider
1110
import paddle.trainer.PyDataProvider2 as dp
12-
import py_paddle.swig_paddle as api
13-
from mnist_util import read_from_mnist
1411
from paddle.trainer_config_helpers import *
12+
13+
import mnist_provider
1514
from trainer import *
1615

1716

@@ -33,41 +32,6 @@ def mnist_network(pixel, label):
3332
return cost
3433

3534

36-
def generator_to_batch(generator, batch_size):
37-
ret_val = list()
38-
for each_item in generator:
39-
ret_val.append(each_item)
40-
if len(ret_val) == batch_size:
41-
yield ret_val
42-
ret_val = list()
43-
if len(ret_val) != 0:
44-
yield ret_val
45-
46-
47-
def input_order_converter(generator):
48-
for each_item in generator:
49-
yield each_item['pixel'], each_item['label']
50-
51-
52-
class MonolithicChainItem(RunnerChainItem):
53-
def finalize(self, context, next_callback):
54-
context.gradient_machine.finish()
55-
56-
def on_pass_end(self, context, next_callback):
57-
context.updater.catchUpWith()
58-
params = context.gradient_machine.getParameters()
59-
for each_param in params:
60-
assert isinstance(each_param, api.Parameter)
61-
value = each_param.getBuf(api.PARAMETER_VALUE)
62-
value = value.copyToNumpyArray()
63-
64-
# Here, we could save parameter to every where you want
65-
print each_param.getName(), value
66-
67-
def __init__(self):
68-
RunnerChainItem.__init__(self)
69-
70-
7135
def main():
7236
mnist = mnist_network()
7337

@@ -91,7 +55,7 @@ def main():
9155
method=mnist_provider.process,
9256
file_list=['./data/raw_data/t10k'],
9357
batch_size=256))
94-
runner.add_chain_item(MonolithicChainItem())
58+
runner.add_chain_item(SaveParamsOnPassEnd())
9559
with runner.use():
9660
for _ in xrange(2):
9761
runner.run_one_pass()

demo/mnist/mnist_provider.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from paddle.trainer.PyDataProvider2 import *
2-
from mnist_util import read_from_mnist
2+
import numpy
33

44

55
# Define a py data provider
@@ -8,5 +8,27 @@
88
'label': integer_value(10)},
99
cache=CacheType.CACHE_PASS_IN_MEM)
1010
def process(settings, filename): # settings is not used currently.
11-
for each in read_from_mnist(filename):
12-
yield each
11+
imgf = filename + "-images-idx3-ubyte"
12+
labelf = filename + "-labels-idx1-ubyte"
13+
f = open(imgf, "rb")
14+
l = open(labelf, "rb")
15+
16+
f.read(16)
17+
l.read(8)
18+
19+
# Define number of samples for train/test
20+
if "train" in filename:
21+
n = 60000
22+
else:
23+
n = 10000
24+
25+
images = numpy.fromfile(
26+
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
27+
images = images / 255.0 * 2.0 - 1.0
28+
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
29+
30+
for i in xrange(n):
31+
yield {"pixel": images[i, :], 'label': labels[i]}
32+
33+
f.close()
34+
l.close()

demo/mnist/mnist_util.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

demo/mnist/trainer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
'RandomInitializeParams', 'BasicLocalParameterUpdater', 'network',
1111
'BasicTrainerDataProvider', 'BasicDataProviderOps',
1212
'BasicGradientMachineTrainOps', 'Counter', 'BatchEvaluate',
13-
'BasicTestDataProvider', 'TestOnPassEnd'
13+
'BasicTestDataProvider', 'TestOnPassEnd', 'SaveParamsOnPassEnd'
1414
]
1515

1616

@@ -580,3 +580,17 @@ def initialize(self, context, next_callback):
580580

581581
BasicTrainerDataProvider = data_provider_creator(True)
582582
BasicTestDataProvider = data_provider_creator(False)
583+
584+
585+
class SaveParamsOnPassEnd(RunnerChainItem):
586+
def __init__(self):
587+
RunnerChainItem.__init__(self)
588+
589+
def on_pass_end(self, context, next_callback):
590+
591+
context.updater.catchUpWith()
592+
params = context.gradient_machine.getParameters()
593+
for param in params:
594+
param.save(param.getName())
595+
596+
next_callback(context)

0 commit comments

Comments
 (0)