Skip to content

Commit 2be7ec9

Browse files
authored
Merge pull request #713 from wangyang59/improveMnistDemo
improve demo/mnist dataProvider speed
2 parents (04fb1fc + 828303b), commit 2be7ec9

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

demo/mnist/mnist_provider.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
from paddle.trainer.PyDataProvider2 import *
2+
import numpy
23

34

45
# Define a py data provider
56
@provider(
67
input_types={'pixel': dense_vector(28 * 28),
7-
'label': integer_value(10)})
8+
'label': integer_value(10)},
9+
cache=CacheType.CACHE_PASS_IN_MEM)
810
def process(settings, filename): # settings is not used currently.
911
imgf = filename + "-images-idx3-ubyte"
1012
labelf = filename + "-labels-idx1-ubyte"
@@ -20,12 +22,13 @@ def process(settings, filename): # settings is not used currently.
2022
else:
2123
n = 10000
2224

23-
for i in range(n):
24-
label = ord(l.read(1))
25-
pixels = []
26-
for j in range(28 * 28):
27-
pixels.append(float(ord(f.read(1))) / 255.0)
28-
yield {"pixel": pixels, 'label': label}
25+
images = numpy.fromfile(
26+
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
27+
images = images / 255.0 * 2.0 - 1.0
28+
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
29+
30+
for i in xrange(n):
31+
yield {"pixel": images[i, :], 'label': labels[i]}
2932

3033
f.close()
3134
l.close()

0 commit comments

Comments
 (0)