
Commit 52dc6a9

Merge branch 'feature/better_infer_interface' into feature/recommendation_v2_api
2 parents c7d259e + d5365bb commit 52dc6a9

4 files changed: 98 additions & 21 deletions

demo/mnist/api_train_v2.py

Lines changed: 10 additions & 15 deletions

@@ -92,12 +92,8 @@ def main():
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
             if event.batch_id % 1000 == 0:
-                result = trainer.test(reader=paddle.batch(
-                    paddle.dataset.mnist.test(), batch_size=256))
-
-                print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
-                    event.pass_id, event.batch_id, event.cost, event.metrics,
-                    result.metrics)
+                print "Pass %d, Batch %d, Cost %f, %s" % (
+                    event.pass_id, event.batch_id, event.cost, event.metrics)
 
             with gzip.open('params.tar.gz', 'w') as f:
                 parameters.to_tar(f)

@@ -123,17 +119,16 @@ def event_handler(event):
     print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
     print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
 
+    test_creator = paddle.dataset.mnist.test()
+    test_data = []
+    for item in test_creator():
+        test_data.append(item[0])
+        if len(test_data) == 100:
+            break
+
     # output is a softmax layer. It returns probabilities.
     # Shape should be (100, 10)
-    probs = paddle.infer(
-        output=predict,
-        parameters=parameters,
-        reader=paddle.batch(
-            paddle.reader.firstn(
-                paddle.reader.map_readers(lambda item: (item[0], ),
-                                          paddle.dataset.mnist.test()),
-                n=100),
-            batch_size=32))
+    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
     print probs.shape
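A note on the rewritten demo call: `paddle.infer` returns the stacked softmax output as a numpy array, so standard numpy post-processing applies. A minimal sketch (only `probs` comes from the demo above; the rest is illustrative):

    # Sketch: turn the (100, 10) probability matrix from paddle.infer
    # into predicted digit labels with numpy.
    import numpy

    labels = numpy.argmax(probs, axis=1)  # one label per test sample
    print labels[:10]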

doc/api/v2/run_logic.rst

Lines changed: 8 additions & 0 deletions

@@ -2,6 +2,7 @@
 Trainer API
 ###########
 
+
 ==========
 Parameters
 ==========

@@ -24,3 +25,10 @@ Event
 
 .. automodule:: paddle.v2.event
     :members:
+
+
+=========
+Inference
+=========
+
+.. autofunction:: paddle.v2.infer
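The new `Inference` heading follows the existing `Parameters`/`Event` sections, and the `autofunction` directive pulls its content from the docstring added to `paddle.v2.infer` in python/paddle/v2/inference.py below.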

python/paddle/v2/data_feeder.py

Lines changed: 3 additions & 0 deletions

@@ -85,6 +85,9 @@ def __init__(self, data_types, feeding=None):
             input_types.append(each[1])
         DataProviderConverter.__init__(self, input_types)
 
+    def __len__(self):
+        return len(self.input_names)
+
     def convert(self, dat, argument=None):
         """
         :param dat: A list of mini-batch data. Each sample is a list or tuple
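`input_names` is filled in `DataFeeder.__init__` from the declared data types, so `len(feeder)` is the number of input slots; `Inference.iter_infer` below relies on it to tell single-slot samples from multi-slot ones. A minimal sketch, assuming the MNIST-style dense-vector type from the demo above and the module import path of this file:

    # Sketch (hypothetical one-slot feeder): len(feeder) counts input slots.
    import paddle.v2 as paddle
    from paddle.v2.data_feeder import DataFeeder

    feeder = DataFeeder(
        data_types=[('image', paddle.data_type.dense_vector(784))])
    assert len(feeder) == 1  # one slot: bare samples get wrapped as [sample]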

python/paddle/v2/inference.py

Lines changed: 77 additions & 6 deletions

@@ -1,9 +1,9 @@
+import numpy
 import py_paddle.swig_paddle as api
-
+import collections
 import topology
+import minibatch
 from data_feeder import DataFeeder
-import itertools
-import numpy
 
 __all__ = ['infer']
 
@@ -21,8 +21,33 @@ def __init__(self, output, parameters):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
 
-    def iter_infer(self, reader, feeding=None):
+    def iter_infer(self, input=None, batch_size=None, reader=None,
+                   feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
+        if reader is None:
+            assert input is not None
+            if not isinstance(input, collections.Iterable):
+                raise TypeError("When reader is None, input should be the "
+                                "whole inference data and should be iterable.")
+
+            if batch_size is None:
+                if not hasattr(input, '__len__'):
+                    raise ValueError("batch_size must be set when the input "
+                                     "data has no length.")
+                batch_size = len(input)
+
+            def __reader_impl__():
+                for each_sample in input:
+                    if len(feeder) == 1:
+                        yield [each_sample]
+                    else:
+                        yield each_sample
+
+            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
+        else:
+            if input is not None:
+                raise ValueError("User should set either input or reader, "
+                                 "not both.")
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
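The `input` branch therefore just manufactures the batched reader a caller would otherwise supply, so the two call styles below should be equivalent (a sketch reusing `predict`, `parameters`, and `test_data` from the MNIST demo above):

    # Style 1: hand over the in-memory samples; iter_infer builds the reader.
    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)

    # Style 2: build the batched reader by hand; each sample is wrapped in a
    # one-element tuple, matching what __reader_impl__ does for one-slot feeders.
    probs = paddle.infer(
        output=predict,
        parameters=parameters,
        reader=paddle.batch(lambda: ((x, ) for x in test_data),
                            batch_size=len(test_data)))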
@@ -46,6 +71,52 @@ def infer(self, field='value', **kwargs):
         return retv
 
 
-def infer(output, parameters, reader, feeding=None, field='value'):
+def infer(output,
+          parameters,
+          input=None,
+          batch_size=None,
+          reader=None,
+          feeding=None,
+          field='value'):
+    """
+    Infer a neural network from the given network output and parameters.
+    The user should pass either a batch of input data or a reader method.
+
+    Example usage:
+
+    .. code-block:: python
+
+        result = paddle.infer(prediction, parameters, input=SomeData,
+                              batch_size=32)
+        print result
+
+    :param output: the output of the neural network to be inferred.
+    :type output: paddle.v2.config_base.Layer
+    :param parameters: the parameters of the neural network.
+    :type parameters: paddle.v2.parameters.Parameters
+    :param input: the input data batch. Should be a python iterable object,
+                  where each element is one data sample.
+    :type input: collections.Iterable
+    :param batch_size: the batch size used when performing inference. Defaults
+                       to the length of input.
+    :type batch_size: int
+    :param reader: an input data reader creator, in batches. If this field is
+                   set, `input` and `batch_size` will be ignored.
+    :type reader: callable
+    :param feeding: the reader feeding dictionary. Generated from the input
+                    value by default.
+    :param field: the prediction field. It should be in [`value`, `ids`].
+                  `value` means returning the prediction probabilities; `ids`
+                  means returning the prediction labels. Default is `value`.
+    :type field: str
+    :return: a numpy array
+    :rtype: numpy.ndarray
+    """
+
     inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(field=field, reader=reader, feeding=feeding)
+    return inferer.infer(
+        field=field,
+        input=input,
+        batch_size=batch_size,
+        reader=reader,
+        feeding=feeding)
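Per the docstring, `batch_size` defaults to `len(input)`, so the whole input is consumed in a single forward pass unless told otherwise; results from multiple batches come back as one numpy array. A small sketch against the MNIST demo above:

    # Sketch: same call as the demo, but split into batches of 32; the
    # returned array still covers all 100 samples, shape (100, 10).
    probs = paddle.infer(
        output=predict, parameters=parameters, input=test_data, batch_size=32)
    print probs.shape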
