Skip to content

Commit 71ab4df

Browse files
committed
Follow comments, remove reader/batch_size in interface.
1 parent 5905d0e commit 71ab4df

File tree

2 files changed

+9
-25
lines changed

2 files changed

+9
-25
lines changed

demo/mnist/api_train_v2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def event_handler(event):
132132

133133
# output is a softmax layer. It returns probabilities.
134134
# Shape should be (100, 10)
135-
probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
135+
probs = paddle.infer(
136+
output_layer=predict, parameters=parameters, input=test_data)
136137
print probs.shape
137138

138139

python/paddle/v2/inference.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010

1111
class Inference(object):
12-
def __init__(self, output, parameters):
13-
topo = topology.Topology(output)
12+
def __init__(self, output_layer, parameters):
13+
topo = topology.Topology(output_layer)
1414
gm = api.GradientMachine.createFromConfigProto(
1515
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
1616
for param in gm.getParameters():
@@ -70,13 +70,7 @@ def infer(self, field='value', **kwargs):
7070
return retv
7171

7272

73-
def infer(output,
74-
parameters,
75-
input=None,
76-
batch_size=None,
77-
reader=None,
78-
feeding=None,
79-
field='value'):
73+
def infer(output_layer, parameters, input=None, feeding=None, field='value'):
8074
"""
8175
Infer a neural network by given neural network output and parameters. The
8276
user should pass either a batch of input data or reader method.
@@ -89,19 +83,13 @@ def infer(output,
8983
batch_size=32)
9084
print result
9185
92-
:param output: output of the neural network that would be inferred
93-
:type output: paddle.v2.config_base.Layer
86+
:param output_layer: output of the neural network that would be inferred
87+
:type output_layer: paddle.v2.config_base.Layer
9488
:param parameters: parameters of the neural network.
9589
:type parameters: paddle.v2.parameters.Parameters
9690
:param input: input data batch. Should be a python iterable object, and each
9791
element is the data batch.
9892
:type input: collections.Iterable
99-
:param batch_size: the batch size when perform inference. Default is the
100-
length of input.
101-
:type batch_size: int
102-
:param reader: input data reader creator in batch. If this field is set, the
103-
`input` and `batch_size` will be ignored.
104-
:type reader: callable
10593
:param feeding: Reader dictionary. Default could generate from input
10694
value.
10795
:param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -112,10 +100,5 @@ def infer(output,
112100
:rtype: numpy.ndarray
113101
"""
114102

115-
inferer = Inference(output=output, parameters=parameters)
116-
return inferer.infer(
117-
field=field,
118-
input=input,
119-
batch_size=batch_size,
120-
reader=reader,
121-
feeding=feeding)
103+
inferer = Inference(output_layer=output_layer, parameters=parameters)
104+
return inferer.infer(field=field, input=input, feeding=feeding)

0 commit comments

Comments
 (0)