Skip to content

Commit d5365bb

Browse files
committed
Add input data interface for inference
1 parent 5f2cbce commit d5365bb

File tree

3 files changed

+99
-16
lines changed

3 files changed

+99
-16
lines changed

demo/mnist/api_train_v2.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def event_handler(event):
9090
print "Pass %d, Batch %d, Cost %f, %s" % (
9191
event.pass_id, event.batch_id, event.cost, event.metrics)
9292
if isinstance(event, paddle.event.EndPass):
93-
result = trainer.test(reader=paddle.reader.batched(
93+
result = trainer.test(reader=paddle.batch(
9494
paddle.dataset.mnist.test(), batch_size=128))
9595
print "Test with Pass %d, Cost %f, %s\n" % (
9696
event.pass_id, result.cost, result.metrics)
@@ -110,17 +110,16 @@ def event_handler(event):
110110
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
111111
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
112112

113+
test_creator = paddle.dataset.mnist.test()
114+
test_data = []
115+
for item in test_creator():
116+
test_data.append(item[0])
117+
if len(test_data) == 100:
118+
break
119+
113120
# output is a softmax layer. It returns probabilities.
114121
# Shape should be (100, 10)
115-
probs = paddle.infer(
116-
output=predict,
117-
parameters=parameters,
118-
reader=paddle.batch(
119-
paddle.reader.firstn(
120-
paddle.reader.map_readers(lambda item: (item[0], ),
121-
paddle.dataset.mnist.test()),
122-
n=100),
123-
batch_size=32))
122+
probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
124123
print probs.shape
125124

126125

doc/api/v2/run_logic.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Trainer API
33
###########
44

5+
56
==========
67
Parameters
78
==========
@@ -24,3 +25,10 @@ Event
2425

2526
.. automodule:: paddle.v2.event
2627
:members:
28+
29+
30+
=========
31+
Inference
32+
=========
33+
34+
.. autofunction:: paddle.v2.infer

python/paddle/v2/inference.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import numpy
12
import py_paddle.swig_paddle as api
2-
3+
import collections
34
import topology
5+
import minibatch
46
from data_feeder import DataFeeder
5-
import itertools
6-
import numpy
77

88
__all__ = ['infer']
99

@@ -21,9 +21,39 @@ def __init__(self, output, parameters):
2121
self.__gradient_machine__ = gm
2222
self.__data_types__ = topo.data_type()
2323

24-
def iter_infer(self, reader, reader_dict=None):
24+
def iter_infer(self,
25+
input=None,
26+
batch_size=None,
27+
reader=None,
28+
reader_dict=None):
2529
if reader_dict is None:
2630
reader_dict = self.default_reader_dict()
31+
32+
if reader is None:
33+
assert input is not None and isinstance(input, collections.Iterable)
34+
if not isinstance(input, collections.Iterable):
35+
raise TypeError("When reader is None, input should be whole "
36+
"inference data and should be iterable")
37+
38+
if batch_size is None:
39+
if not hasattr(input, '__len__'):
40+
raise ValueError("Should set batch size when input data "
41+
"don't contain length.")
42+
batch_size = len(input)
43+
44+
def __reader_impl__():
45+
for each_sample in input:
46+
if len(reader_dict) == 1:
47+
yield [each_sample]
48+
else:
49+
yield each_sample
50+
51+
reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
52+
else:
53+
if input is not None:
54+
raise ValueError("User should set either input or reader, "
55+
"should not set them both.")
56+
2757
feeder = DataFeeder(self.__data_types__, reader_dict)
2858
self.__gradient_machine__.start()
2959
for data_batch in reader():
@@ -54,6 +84,52 @@ def default_reader_dict(self):
5484
return reader_dict
5585

5686

57-
def infer(output, parameters, reader, reader_dict=None, field='value'):
87+
def infer(output,
88+
parameters,
89+
input=None,
90+
batch_size=None,
91+
reader=None,
92+
reader_dict=None,
93+
field='value'):
94+
"""
95+
Infer a neural network using the given network output and parameters. The
96+
user should pass either a batch of input data or reader method.
97+
98+
Example usages:
99+
100+
.. code-block:: python
101+
102+
result = paddle.infer(prediction, parameters, input=SomeData,
103+
batch_size=32)
104+
print result
105+
106+
:param output: output of the neural network that would be inferred
107+
:type output: paddle.v2.config_base.Layer
108+
:param parameters: parameters of the neural network.
109+
:type parameters: paddle.v2.parameters.Parameters
110+
:param input: input data batch. Should be a python iterable object, and each
111+
element is the data batch.
112+
:type input: collections.Iterable
113+
:param batch_size: the batch size when perform inference. Default is the
114+
length of input.
115+
:type batch_size: int
116+
:param reader: input data reader creator in batch. If this field is set, the
117+
`input` and `batch_size` will be ignored.
118+
:type reader: callable
119+
:param reader_dict: Reader dictionary. By default it is generated from the input
120+
value.
121+
:param field: The prediction field. It should be in [`value`, `ids`]. `value`
122+
means return the prediction probabilities, `ids` means return
123+
the prediction labels. Default is `value`
124+
:type field: str
125+
:return: a numpy array
126+
:rtype: numpy.ndarray
127+
"""
128+
58129
inferer = Inference(output=output, parameters=parameters)
59-
return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
130+
return inferer.infer(
131+
field=field,
132+
input=input,
133+
batch_size=batch_size,
134+
reader=reader,
135+
reader_dict=reader_dict)

0 commit comments

Comments
 (0)