1
+ import numpy
1
2
import py_paddle .swig_paddle as api
2
-
3
+ import collections
3
4
import topology
5
+ import minibatch
4
6
from data_feeder import DataFeeder
5
- import itertools
6
- import numpy
7
7
8
8
__all__ = ['infer' ]
9
9
@@ -21,8 +21,33 @@ def __init__(self, output, parameters):
21
21
self .__gradient_machine__ = gm
22
22
self .__data_types__ = topo .data_type ()
23
23
24
- def iter_infer (self , reader , feeding = None ):
24
+ def iter_infer (self , input = None , batch_size = None , reader = None ,
25
+ feeding = None ):
25
26
feeder = DataFeeder (self .__data_types__ , feeding )
27
+ if reader is None :
28
+ assert input is not None and isinstance (input , collections .Iterable )
29
+ if not isinstance (input , collections .Iterable ):
30
+ raise TypeError ("When reader is None, input should be whole "
31
+ "inference data and should be iterable" )
32
+
33
+ if batch_size is None :
34
+ if not hasattr (input , '__len__' ):
35
+ raise ValueError ("Should set batch size when input data "
36
+ "don't contain length." )
37
+ batch_size = len (input )
38
+
39
+ def __reader_impl__ ():
40
+ for each_sample in input :
41
+ if len (feeder ) == 1 :
42
+ yield [each_sample ]
43
+ else :
44
+ yield each_sample
45
+
46
+ reader = minibatch .batch (__reader_impl__ , batch_size = batch_size )
47
+ else :
48
+ if input is not None :
49
+ raise ValueError ("User should set either input or reader, "
50
+ "should not set them both." )
26
51
self .__gradient_machine__ .start ()
27
52
for data_batch in reader ():
28
53
yield self .__gradient_machine__ .forwardTest (feeder (data_batch ))
@@ -46,6 +71,52 @@ def infer(self, field='value', **kwargs):
46
71
return retv
47
72
48
73
49
- def infer (output , parameters , reader , feeding = None , field = 'value' ):
74
+ def infer (output ,
75
+ parameters ,
76
+ input = None ,
77
+ batch_size = None ,
78
+ reader = None ,
79
+ feeding = None ,
80
+ field = 'value' ):
81
+ """
82
+ Infer a neural network by given neural network output and parameters. The
83
+ user should pass either a batch of input data or reader method.
84
+
85
+ Example usages:
86
+
87
+ .. code-block:: python
88
+
89
+ result = paddle.infer(prediction, parameters, input=SomeData,
90
+ batch_size=32)
91
+ print result
92
+
93
+ :param output: output of the neural network that would be inferred
94
+ :type output: paddle.v2.config_base.Layer
95
+ :param parameters: parameters of the neural network.
96
+ :type parameters: paddle.v2.parameters.Parameters
97
+ :param input: input data batch. Should be a python iterable object, and each
98
+ element is the data batch.
99
+ :type input: collections.Iterable
100
+ :param batch_size: the batch size when perform inference. Default is the
101
+ length of input.
102
+ :type batch_size: int
103
+ :param reader: input data reader creator in batch. If this field is set, the
104
+ `input` and `batch_size` will be ignored.
105
+ :type reader: callable
106
+ :param feeding: Reader dictionary. Default could generate from input
107
+ value.
108
+ :param field: The prediction field. It should in [`value`, `ids`]. `value`
109
+ means return the prediction probabilities, `ids` means return
110
+ the prediction labels. Default is `value`
111
+ :type field: str
112
+ :return: a numpy array
113
+ :rtype: numpy.ndarray
114
+ """
115
+
50
116
inferer = Inference (output = output , parameters = parameters )
51
- return inferer .infer (field = field , reader = reader , feeding = feeding )
117
+ return inferer .infer (
118
+ field = field ,
119
+ input = input ,
120
+ batch_size = batch_size ,
121
+ reader = reader ,
122
+ feeding = feeding )
0 commit comments