Skip to content

Commit ff4e046

Browse files
author
wangyang59
committed
improve demo/mnist dataProvider speed
1 parent 5ac16e5 commit ff4e046

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

demo/mnist/mnist_provider.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
from paddle.trainer.PyDataProvider2 import *
2-
2+
import numpy
33

44
# Define a py data provider
55
@provider(
66
input_types={'pixel': dense_vector(28 * 28),
7-
'label': integer_value(10)})
7+
'label': integer_value(10)},
8+
cache=CacheType.CACHE_PASS_IN_MEM)
89
def process(settings, filename): # settings is not used currently.
910
imgf = filename + "-images-idx3-ubyte"
1011
labelf = filename + "-labels-idx1-ubyte"
@@ -19,13 +20,13 @@ def process(settings, filename): # settings is not used currently.
1920
n = 60000
2021
else:
2122
n = 10000
22-
23-
for i in range(n):
24-
label = ord(l.read(1))
25-
pixels = []
26-
for j in range(28 * 28):
27-
pixels.append(float(ord(f.read(1))) / 255.0)
28-
yield {"pixel": pixels, 'label': label}
29-
23+
24+
images = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)).astype('float32')
25+
images = images / 255.0 * 2.0 - 1.0
26+
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
27+
28+
for i in xrange(n):
29+
yield {"pixel": images[i, :], 'label': labels[i]}
30+
3031
f.close()
3132
l.close()

0 commit comments

Comments
 (0)