9
9
10
10
11
11
class Inference (object ):
12
- def __init__ (self , output , parameters ):
13
- topo = topology .Topology (output )
12
+ def __init__ (self , output_layer , parameters ):
13
+ topo = topology .Topology (output_layer )
14
14
gm = api .GradientMachine .createFromConfigProto (
15
15
topo .proto (), api .CREATE_MODE_TESTING , [api .PARAMETER_VALUE ])
16
16
for param in gm .getParameters ():
@@ -21,33 +21,16 @@ def __init__(self, output, parameters):
21
21
self .__gradient_machine__ = gm
22
22
self .__data_types__ = topo .data_type ()
23
23
24
- def iter_infer (self , input = None , batch_size = None , reader = None ,
25
- feeding = None ):
24
+ def iter_infer (self , input , feeding = None ):
26
25
feeder = DataFeeder (self .__data_types__ , feeding )
27
- if reader is None :
28
- assert input is not None and isinstance (input , collections .Iterable )
29
- if not isinstance (input , collections .Iterable ):
30
- raise TypeError ("When reader is None, input should be whole "
31
- "inference data and should be iterable" )
32
-
33
- if batch_size is None :
34
- if not hasattr (input , '__len__' ):
35
- raise ValueError ("Should set batch size when input data "
36
- "don't contain length." )
37
- batch_size = len (input )
38
-
39
- def __reader_impl__ ():
40
- for each_sample in input :
41
- if len (feeder ) == 1 :
42
- yield [each_sample ]
43
- else :
44
- yield each_sample
45
-
46
- reader = minibatch .batch (__reader_impl__ , batch_size = batch_size )
47
- else :
48
- if input is not None :
49
- raise ValueError ("User should set either input or reader, "
50
- "should not set them both." )
26
+ batch_size = len (input )
27
+
28
+ def __reader_impl__ ():
29
+ for each_sample in input :
30
+ yield each_sample
31
+
32
+ reader = minibatch .batch (__reader_impl__ , batch_size = batch_size )
33
+
51
34
self .__gradient_machine__ .start ()
52
35
for data_batch in reader ():
53
36
yield self .__gradient_machine__ .forwardTest (feeder (data_batch ))
@@ -71,13 +54,7 @@ def infer(self, field='value', **kwargs):
71
54
return retv
72
55
73
56
74
- def infer (output ,
75
- parameters ,
76
- input = None ,
77
- batch_size = None ,
78
- reader = None ,
79
- feeding = None ,
80
- field = 'value' ):
57
+ def infer (output_layer , parameters , input , feeding = None , field = 'value' ):
81
58
"""
82
59
Infer a neural network by the given output layer and parameters. The
83
60
user should pass a whole batch of input data to be inferred.
@@ -90,19 +67,13 @@ def infer(output,
90
67
batch_size=32)
91
68
print result
92
69
93
- :param output : output of the neural network that would be inferred
94
- :type output : paddle.v2.config_base.Layer
70
+ :param output_layer : output of the neural network that would be inferred
71
+ :type output_layer : paddle.v2.config_base.Layer
95
72
:param parameters: parameters of the neural network.
96
73
:type parameters: paddle.v2.parameters.Parameters
97
74
:param input: input data batch. Should be a python iterable object, and each
98
75
element is one data sample.
99
76
:type input: collections.Iterable
100
- :param batch_size: the batch size when perform inference. Default is the
101
- length of input.
102
- :type batch_size: int
103
- :param reader: input data reader creator in batch. If this field is set, the
104
- `input` and `batch_size` will be ignored.
105
- :type reader: callable
106
77
:param feeding: Reader dictionary. Default could generate from input
107
78
value.
108
79
:param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -113,10 +84,5 @@ def infer(output,
113
84
:rtype: numpy.ndarray
114
85
"""
115
86
116
- inferer = Inference (output = output , parameters = parameters )
117
- return inferer .infer (
118
- field = field ,
119
- input = input ,
120
- batch_size = batch_size ,
121
- reader = reader ,
122
- feeding = feeding )
87
+ inferer = Inference (output_layer = output_layer , parameters = parameters )
88
+ return inferer .infer (field = field , input = input , feeding = feeding )
0 commit comments