Skip to content

Commit 4c24ac1

Browse files
committed
Init inferencer.
1 parent 91f13e4 commit 4c24ac1

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

python/paddle/v2/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
from . import reader
2525
import attr
2626
import pooling
27+
import inferencer
2728
import py_paddle.swig_paddle as api
2829

2930
__all__ = [
3031
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
3132
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
32-
'topology'
33+
'topology', 'inferencer', 'infer'
3334
]
3435

3536

@@ -39,3 +40,6 @@ def init(**kwargs):
3940
args.append('--%s=%s' % (key, str(kwargs[key])))
4041

4142
api.initPaddle(*args)
43+
44+
45+
infer = inferencer.infer

python/paddle/v2/inferencer.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import py_paddle.swig_paddle as api
2+
3+
import topology
4+
from data_feeder import DataFeeder
5+
import itertools
6+
import numpy
7+
8+
__all__ = ['InferenceEngine', 'infer']
9+
10+
11+
class InferenceEngine(object):
12+
def __init__(self, output, parameters):
13+
topo = topology.Topology(output)
14+
gm = api.GradientMachine.createFromConfigProto(
15+
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
16+
for param in gm.getParameters():
17+
val = param.getBuf(api.PARAMETER_VALUE)
18+
name = param.getName()
19+
assert isinstance(val, api.Matrix)
20+
val.copyFromNumpyMat(parameters.get(name))
21+
self.__gradient_machine__ = gm
22+
self.__data_types__ = topo.data_type()
23+
24+
def iter_infer(self, reader, reader_dict=None):
25+
feeder = DataFeeder(self.__data_types__, reader_dict)
26+
out_args = api.Arguments.createArguments(0)
27+
self.__gradient_machine__.start()
28+
for data_batch in reader():
29+
yield self.__gradient_machine__.forwardTest(
30+
feeder(data_batch), out_args, api.PASS_TEST)
31+
self.__gradient_machine__.finish()
32+
33+
def iter_infer_field(self, field, **kwargs):
34+
for result in self.iter_infer(**kwargs):
35+
yield [each_result[field] for each_result in result]
36+
37+
def infer(self, field='value', **kwargs):
38+
retv = []
39+
for result in itertools.izip(
40+
self.iter_infer_field(
41+
field=field, **kwargs)):
42+
retv.append(numpy.concatenate(result))
43+
return retv
44+
45+
def default_reader_dict(self):
46+
reader_dict = dict()
47+
for i, tp in enumerate(self.__data_types__):
48+
reader_dict[tp[0]] = i
49+
return reader_dict
50+
51+
52+
def infer(output, parameters, reader, reader_dict=None, field='value'):
53+
inferer = InferenceEngine(output=output, parameters=parameters)
54+
return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)

0 commit comments

Comments
 (0)