9
9
10
10
11
11
class Inference (object ):
12
- def __init__ (self , output , parameters ):
13
- topo = topology .Topology (output )
12
+ def __init__ (self , output_layer , parameters ):
13
+ topo = topology .Topology (output_layer )
14
14
gm = api .GradientMachine .createFromConfigProto (
15
15
topo .proto (), api .CREATE_MODE_TESTING , [api .PARAMETER_VALUE ])
16
16
for param in gm .getParameters ():
@@ -70,13 +70,7 @@ def infer(self, field='value', **kwargs):
70
70
return retv
71
71
72
72
73
- def infer (output ,
74
- parameters ,
75
- input = None ,
76
- batch_size = None ,
77
- reader = None ,
78
- feeding = None ,
79
- field = 'value' ):
73
+ def infer (output_layer , parameters , input = None , feeding = None , field = 'value' ):
80
74
"""
81
75
Infer a neural network by given neural network output and parameters. The
82
76
user should pass either a batch of input data or reader method.
@@ -89,19 +83,13 @@ def infer(output,
89
83
batch_size=32)
90
84
print result
91
85
92
- :param output : output of the neural network that would be inferred
93
- :type output : paddle.v2.config_base.Layer
86
+ :param output_layer : output of the neural network that would be inferred
87
+ :type output_layer : paddle.v2.config_base.Layer
94
88
:param parameters: parameters of the neural network.
95
89
:type parameters: paddle.v2.parameters.Parameters
96
90
:param input: input data batch. Should be a python iterable object, and each
97
91
element is the data batch.
98
92
:type input: collections.Iterable
99
- :param batch_size: the batch size when perform inference. Default is the
100
- length of input.
101
- :type batch_size: int
102
- :param reader: input data reader creator in batch. If this field is set, the
103
- `input` and `batch_size` will be ignored.
104
- :type reader: callable
105
93
:param feeding: Reader dictionary. Default could generate from input
106
94
value.
107
95
:param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -112,10 +100,5 @@ def infer(output,
112
100
:rtype: numpy.ndarray
113
101
"""
114
102
115
- inferer = Inference (output = output , parameters = parameters )
116
- return inferer .infer (
117
- field = field ,
118
- input = input ,
119
- batch_size = batch_size ,
120
- reader = reader ,
121
- feeding = feeding )
103
+ inferer = Inference (output_layer = output_layer , parameters = parameters )
104
+ return inferer .infer (field = field , input = input , feeding = feeding )
0 commit comments