Skip to content

Commit 9ba231d

Browse files
committed
Complete inferencer.
1 parent 4c24ac1 commit 9ba231d

File tree

4 files changed

+61
-28
lines changed

4 files changed

+61
-28
lines changed

demo/mnist/api_train_v2.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ def event_handler(event):
4444
batch_size=32),
4545
event_handler=event_handler)
4646

47+
# output is a softmax layer. It returns probabilities.
48+
# Shape should be (100, 10)
49+
probs = paddle.infer(
50+
output=inference,
51+
parameters=parameters,
52+
reader=paddle.reader.batched(
53+
paddle.reader.limited(
54+
paddle.reader.map_readers(lambda item: (item[0], ),
55+
paddle.dataset.mnist.test()),
56+
limit=100),
57+
batch_size=32))
58+
print probs.shape
59+
4760

4861
if __name__ == '__main__':
4962
main()

python/paddle/v2/dataset/mnist.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,25 @@ def reader():
3535
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
3636
l.stdout.read(8) # skip some magic bytes
3737

38-
while True:
39-
labels = numpy.fromfile(
40-
l.stdout, 'ubyte', count=buffer_size).astype("int")
38+
try: # reader could be break.
39+
while True:
40+
labels = numpy.fromfile(
41+
l.stdout, 'ubyte', count=buffer_size).astype("int")
4142

42-
if labels.size != buffer_size:
43-
break # numpy.fromfile returns empty slice after EOF.
43+
if labels.size != buffer_size:
44+
break # numpy.fromfile returns empty slice after EOF.
4445

45-
images = numpy.fromfile(
46-
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
47-
(buffer_size, 28 * 28)).astype('float32')
46+
images = numpy.fromfile(
47+
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
48+
(buffer_size, 28 * 28)).astype('float32')
4849

49-
images = images / 255.0 * 2.0 - 1.0
50+
images = images / 255.0 * 2.0 - 1.0
5051

51-
for i in xrange(buffer_size):
52-
yield images[i, :], int(labels[i])
53-
54-
m.terminate()
55-
l.terminate()
52+
for i in xrange(buffer_size):
53+
yield images[i, :], int(labels[i])
54+
finally:
55+
m.terminate()
56+
l.terminate()
5657

5758
return reader
5859

python/paddle/v2/inferencer.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,36 @@ def __init__(self, output, parameters):
1616
for param in gm.getParameters():
1717
val = param.getBuf(api.PARAMETER_VALUE)
1818
name = param.getName()
19-
assert isinstance(val, api.Matrix)
20-
val.copyFromNumpyMat(parameters.get(name))
19+
assert isinstance(val, api.Vector)
20+
val.copyFromNumpyArray(parameters.get(name).flatten())
2121
self.__gradient_machine__ = gm
2222
self.__data_types__ = topo.data_type()
2323

2424
def iter_infer(self, reader, reader_dict=None):
    """
    Lazily run inference over the batches produced by ``reader``.

    :param reader: callable returning an iterable of data batches.
    :param reader_dict: optional mapping used by DataFeeder to locate each
        data layer's slot in a batch; defaults to
        ``self.default_reader_dict()``.
    :return: generator yielding the forward-pass result for each batch.
    """
    if reader_dict is None:
        reader_dict = self.default_reader_dict()
    feeder = DataFeeder(self.__data_types__, reader_dict)
    self.__gradient_machine__.start()
    try:
        for data_batch in reader():
            yield self.__gradient_machine__.forwardTest(feeder(data_batch))
    finally:
        # Tear the machine down even when the consumer abandons the
        # generator before exhausting it (breaks out of its loop, or the
        # generator is garbage-collected); otherwise finish() never runs.
        self.__gradient_machine__.finish()
3232

3333
def iter_infer_field(self, field, **kwargs):
    """
    Run inference and yield, for every batch, only the given result field.

    :param field: name of the field to extract from each output's result.
    :param kwargs: forwarded verbatim to :meth:`iter_infer`.
    :return: generator of lists with one extracted value per output.
    """
    for batch_result in self.iter_infer(**kwargs):
        extracted = [one_output[field] for one_output in batch_result]
        yield extracted
3636

3737
def infer(self, field='value', **kwargs):
    """
    Run inference and concatenate the per-batch results for each output.

    :param field: result field to collect from every batch result
        (default 'value', the forward output).
    :param kwargs: forwarded to :meth:`iter_infer_field` and on to
        :meth:`iter_infer` (e.g. ``reader=...``).
    :return: a single numpy array when the network has exactly one
        output, otherwise a list of arrays (one per output). An empty
        list when the reader yields no batches.
    """
    retv = None
    for result in self.iter_infer_field(field=field, **kwargs):
        if retv is None:
            # One *independent* accumulator per network output.
            # NOTE: ``[[]] * len(result)`` would alias a single shared
            # list, mixing every output's values together.
            retv = [[] for _ in result]
        for i, item in enumerate(result):
            retv[i].append(item)
    if retv is None:
        # The reader produced no batches at all; nothing to concatenate.
        return []
    retv = [numpy.concatenate(out) for out in retv]
    if len(retv) == 1:
        return retv[0]
    return retv
4449

4550
def default_reader_dict(self):
4651
reader_dict = dict()

python/paddle/v2/reader/decorator.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
__all__ = [
1616
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
17-
'ComposeNotAligned', 'batched'
17+
'ComposeNotAligned', 'batched', 'limited'
1818
]
1919

20-
from Queue import Queue
21-
from threading import Thread
2220
import itertools
2321
import random
22+
from Queue import Queue
23+
from threading import Thread
2424

2525

2626
def map_readers(func, *readers):
@@ -213,3 +213,17 @@ def batched_reader():
213213
yield batch
214214

215215
return batched_reader
216+
217+
218+
def limited(reader, limit):
    """
    Limit the max number of samples that reader could return.

    :param reader: the data reader to wrap.
    :param limit: maximum number of samples the wrapped reader yields.
    :return: a reader yielding at most ``limit`` samples.
    """

    def limited_reader():
        # islice stops after `limit` items — same effect as the manual
        # enumerate-and-break loop, but iterates at C speed and rejects
        # a negative limit with a clear ValueError instead of silently
        # yielding everything.
        for item in itertools.islice(reader(), limit):
            yield item

    return limited_reader

0 commit comments

Comments
 (0)