Commit e7ca8b2

Merge pull request #1499 from reyoung/feature/inferencer
Complete inferencer
2 parents cdecd53 + 500d883 · commit e7ca8b2

5 files changed: +112 −18 lines

demo/mnist/api_train_v2.py

Lines changed: 13 additions & 0 deletions

@@ -44,6 +44,19 @@ def event_handler(event):
             batch_size=32),
         event_handler=event_handler)

+    # output is a softmax layer. It returns probabilities.
+    # Shape should be (100, 10)
+    probs = paddle.infer(
+        output=inference,
+        parameters=parameters,
+        reader=paddle.reader.batched(
+            paddle.reader.firstn(
+                paddle.reader.map_readers(lambda item: (item[0], ),
+                                          paddle.dataset.mnist.test()),
+                n=100),
+            batch_size=32))
+    print probs.shape
+

 if __name__ == '__main__':
     main()
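
For reference, a minimal sketch of how the returned array might be consumed, assuming probs is the (100, 10) array whose shape is printed above; numpy.argmax over the class axis gives the predicted digit for each of the 100 test images:

    import numpy

    # probs has shape (100, 10): one row of class probabilities per image.
    predictions = numpy.argmax(probs, axis=1)  # most likely digit per image
    print predictions[:10]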

python/paddle/v2/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -24,13 +24,14 @@
 from . import reader
 import attr
 import pooling
+import inferencer
 import networks
 import py_paddle.swig_paddle as api

 __all__ = [
     'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
     'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
-    'topology', 'networks'
+    'topology', 'networks', 'inferencer', 'infer'
 ]


@@ -40,3 +41,6 @@ def init(**kwargs):
         args.append('--%s=%s' % (key, str(kwargs[key])))

     api.initPaddle(*args)
+
+
+infer = inferencer.infer
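
With this change infer is exposed at the package top level, so user code does not need to import the submodule. A tiny sketch, assuming the package is imported as paddle.v2 the way the demo above does:

    import paddle.v2 as paddle

    # paddle.infer is simply an alias for paddle.inferencer.infer,
    # bound at the bottom of __init__.py.
    assert paddle.infer is paddle.inferencer.infer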

python/paddle/v2/dataset/mnist.py

Lines changed: 15 additions & 14 deletions

@@ -35,24 +35,25 @@ def reader():
         l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
         l.stdout.read(8)  # skip some magic bytes

-        while True:
-            labels = numpy.fromfile(
-                l.stdout, 'ubyte', count=buffer_size).astype("int")
+        try:  # the reader may stop before EOF; clean up either way.
+            while True:
+                labels = numpy.fromfile(
+                    l.stdout, 'ubyte', count=buffer_size).astype("int")

-            if labels.size != buffer_size:
-                break  # numpy.fromfile returns empty slice after EOF.
+                if labels.size != buffer_size:
+                    break  # numpy.fromfile returns empty slice after EOF.

-            images = numpy.fromfile(
-                m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
-                    (buffer_size, 28 * 28)).astype('float32')
+                images = numpy.fromfile(
+                    m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
+                        (buffer_size, 28 * 28)).astype('float32')

-            images = images / 255.0 * 2.0 - 1.0
+                images = images / 255.0 * 2.0 - 1.0

-            for i in xrange(buffer_size):
-                yield images[i, :], int(labels[i])
-
-        m.terminate()
-        l.terminate()
+                for i in xrange(buffer_size):
+                    yield images[i, :], int(labels[i])
+            finally:
+                m.terminate()
+                l.terminate()

     return reader

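
The try/finally matters because a reader may be abandoned before it reaches EOF (for example when wrapped in firstn, added below in reader/decorator.py); when the half-consumed generator is closed, Python raises GeneratorExit at the suspended yield and the finally block still runs, so the zcat subprocesses are terminated. A toy sketch of the same pattern, with an ordinary file standing in for the subprocess pipes:

    def reader():
        handle = open(__file__, 'rb')  # stands in for the zcat pipes
        try:
            while True:
                data = handle.read(4)
                if not data:
                    break
                yield data
        finally:
            handle.close()  # runs even if the consumer stops early
            print 'cleaned up'

    r = reader()
    next(r)     # consume a single item...
    r.close()   # ...closing the generator still executes the finally block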

python/paddle/v2/inferencer.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+import py_paddle.swig_paddle as api
+
+import topology
+from data_feeder import DataFeeder
+import itertools
+import numpy
+
+__all__ = ['Inference', 'infer']
+
+
+class Inference(object):
+    def __init__(self, output, parameters):
+        topo = topology.Topology(output)
+        gm = api.GradientMachine.createFromConfigProto(
+            topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
+        for param in gm.getParameters():
+            val = param.getBuf(api.PARAMETER_VALUE)
+            name = param.getName()
+            assert isinstance(val, api.Vector)
+            val.copyFromNumpyArray(parameters.get(name).flatten())
+        self.__gradient_machine__ = gm
+        self.__data_types__ = topo.data_type()
+
+    def iter_infer(self, reader, reader_dict=None):
+        if reader_dict is None:
+            reader_dict = self.default_reader_dict()
+        feeder = DataFeeder(self.__data_types__, reader_dict)
+        self.__gradient_machine__.start()
+        for data_batch in reader():
+            yield self.__gradient_machine__.forwardTest(feeder(data_batch))
+        self.__gradient_machine__.finish()
+
+    def iter_infer_field(self, field, **kwargs):
+        for result in self.iter_infer(**kwargs):
+            yield [each_result[field] for each_result in result]
+
+    def infer(self, field='value', **kwargs):
+        retv = None
+        for result in self.iter_infer_field(field=field, **kwargs):
+            if retv is None:
+                retv = [[] for _ in xrange(len(result))]
+            for i, item in enumerate(result):
+                retv[i].append(item)
+        retv = [numpy.concatenate(out) for out in retv]
+        if len(retv) == 1:
+            return retv[0]
+        else:
+            return retv
+
+    def default_reader_dict(self):
+        reader_dict = dict()
+        for i, tp in enumerate(self.__data_types__):
+            reader_dict[tp[0]] = i
+        return reader_dict
+
+
+def infer(output, parameters, reader, reader_dict=None, field='value'):
+    inferer = Inference(output=output, parameters=parameters)
+    return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
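
A small usage sketch of the class itself, assuming the inference output layer and trained parameters from the MNIST demo above; iter_infer_field streams one list of per-batch results at a time instead of concatenating everything the way infer does:

    inferer = paddle.inferencer.Inference(
        output=inference, parameters=parameters)

    test_reader = paddle.reader.batched(
        paddle.reader.map_readers(lambda item: (item[0], ),
                                  paddle.dataset.mnist.test()),
        batch_size=128)

    # handle results batch by batch instead of building one big array
    for batch_results in inferer.iter_infer_field(field='value',
                                                  reader=test_reader):
        for value in batch_results:  # one numpy array per output layer
            print value.shape        # e.g. (128, 10) for the softmax output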

python/paddle/v2/reader/decorator.py

Lines changed: 20 additions & 3 deletions

@@ -14,13 +14,13 @@

 __all__ = [
     'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
-    'ComposeNotAligned', 'batched'
+    'ComposeNotAligned', 'batched', 'firstn'
 ]

-from Queue import Queue
-from threading import Thread
 import itertools
 import random
+from Queue import Queue
+from threading import Thread


 def map_readers(func, *readers):
@@ -213,3 +213,20 @@ def batched_reader():
         yield batch

     return batched_reader
+
+
+def firstn(reader, n):
+    """
+    Limit the maximum number of samples that the reader can return.
+    """
+
+    # TODO(yuyang18): Check whether simply dropping the reader would clean up
+    # the opened resources or not.
+
+    def firstn_reader():
+        for i, item in enumerate(reader()):
+            if i == n:
+                break
+            yield item
+
+    return firstn_reader
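
A quick sketch of the new decorator on a made-up reader:

    def count_reader():
        for i in xrange(10):
            yield i

    limited = paddle.reader.firstn(count_reader, n=5)
    print list(limited())  # [0, 1, 2, 3, 4]

In the MNIST demo above it is composed with map_readers and batched to cap inference at the first 100 test images.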
