Skip to content

Commit 4feb501

Browse files
authored
Merge pull request #1561 from reyoung/feature/better_infer_interface
Add input data interface for inference
2 parents 7e98163 + 05b45e1 commit 4feb501

File tree

2 files changed

+19
-52
lines changed

2 files changed

+19
-52
lines changed

demo/mnist/api_train_v2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,14 @@ def event_handler(event):
122122
test_creator = paddle.dataset.mnist.test()
123123
test_data = []
124124
for item in test_creator():
125-
test_data.append(item[0])
125+
test_data.append((item[0], ))
126126
if len(test_data) == 100:
127127
break
128128

129129
# output is a softmax layer. It returns probabilities.
130130
# Shape should be (100, 10)
131-
probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
131+
probs = paddle.infer(
132+
output_layer=predict, parameters=parameters, input=test_data)
132133
print probs.shape
133134

134135

python/paddle/v2/inference.py

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010

1111
class Inference(object):
12-
def __init__(self, output, parameters):
13-
topo = topology.Topology(output)
12+
def __init__(self, output_layer, parameters):
13+
topo = topology.Topology(output_layer)
1414
gm = api.GradientMachine.createFromConfigProto(
1515
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
1616
for param in gm.getParameters():
@@ -21,33 +21,16 @@ def __init__(self, output, parameters):
2121
self.__gradient_machine__ = gm
2222
self.__data_types__ = topo.data_type()
2323

24-
def iter_infer(self, input=None, batch_size=None, reader=None,
25-
feeding=None):
24+
def iter_infer(self, input, feeding=None):
2625
feeder = DataFeeder(self.__data_types__, feeding)
27-
if reader is None:
28-
assert input is not None and isinstance(input, collections.Iterable)
29-
if not isinstance(input, collections.Iterable):
30-
raise TypeError("When reader is None, input should be whole "
31-
"inference data and should be iterable")
32-
33-
if batch_size is None:
34-
if not hasattr(input, '__len__'):
35-
raise ValueError("Should set batch size when input data "
36-
"don't contain length.")
37-
batch_size = len(input)
38-
39-
def __reader_impl__():
40-
for each_sample in input:
41-
if len(feeder) == 1:
42-
yield [each_sample]
43-
else:
44-
yield each_sample
45-
46-
reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
47-
else:
48-
if input is not None:
49-
raise ValueError("User should set either input or reader, "
50-
"should not set them both.")
26+
batch_size = len(input)
27+
28+
def __reader_impl__():
29+
for each_sample in input:
30+
yield each_sample
31+
32+
reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
33+
5134
self.__gradient_machine__.start()
5235
for data_batch in reader():
5336
yield self.__gradient_machine__.forwardTest(feeder(data_batch))
@@ -71,13 +54,7 @@ def infer(self, field='value', **kwargs):
7154
return retv
7255

7356

74-
def infer(output,
75-
parameters,
76-
input=None,
77-
batch_size=None,
78-
reader=None,
79-
feeding=None,
80-
field='value'):
57+
def infer(output_layer, parameters, input, feeding=None, field='value'):
8158
"""
8259
Infer a neural network using the given output layer and parameters. The
8360
user should pass a batch of input data for inference.
@@ -90,19 +67,13 @@ def infer(output,
9067
batch_size=32)
9168
print result
9269
93-
:param output: output of the neural network that would be inferred
94-
:type output: paddle.v2.config_base.Layer
70+
:param output_layer: output of the neural network that would be inferred
71+
:type output_layer: paddle.v2.config_base.Layer
9572
:param parameters: parameters of the neural network.
9673
:type parameters: paddle.v2.parameters.Parameters
9774
:param input: input data batch. Should be a python iterable object, and each
9875
element is the data batch.
9976
:type input: collections.Iterable
100-
:param batch_size: the batch size when perform inference. Default is the
101-
length of input.
102-
:type batch_size: int
103-
:param reader: input data reader creator in batch. If this field is set, the
104-
`input` and `batch_size` will be ignored.
105-
:type reader: callable
10677
:param feeding: Reader dictionary. Default could generate from input
10778
value.
10879
:param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -113,10 +84,5 @@ def infer(output,
11384
:rtype: numpy.ndarray
11485
"""
11586

116-
inferer = Inference(output=output, parameters=parameters)
117-
return inferer.infer(
118-
field=field,
119-
input=input,
120-
batch_size=batch_size,
121-
reader=reader,
122-
feeding=feeding)
87+
inferer = Inference(output_layer=output_layer, parameters=parameters)
88+
return inferer.infer(field=field, input=input, feeding=feeding)

0 commit comments

Comments
 (0)