|
| 1 | +import py_paddle.swig_paddle as api |
| 2 | + |
| 3 | +import topology |
| 4 | +from data_feeder import DataFeeder |
| 5 | +import itertools |
| 6 | +import numpy |
| 7 | + |
| 8 | +__all__ = ['Inference', 'infer'] |
| 9 | + |
| 10 | + |
| 11 | +class Inference(object): |
| 12 | + def __init__(self, output, parameters): |
| 13 | + topo = topology.Topology(output) |
| 14 | + gm = api.GradientMachine.createFromConfigProto( |
| 15 | + topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) |
| 16 | + for param in gm.getParameters(): |
| 17 | + val = param.getBuf(api.PARAMETER_VALUE) |
| 18 | + name = param.getName() |
| 19 | + assert isinstance(val, api.Vector) |
| 20 | + val.copyFromNumpyArray(parameters.get(name).flatten()) |
| 21 | + self.__gradient_machine__ = gm |
| 22 | + self.__data_types__ = topo.data_type() |
| 23 | + |
| 24 | + def iter_infer(self, reader, reader_dict=None): |
| 25 | + if reader_dict is None: |
| 26 | + reader_dict = self.default_reader_dict() |
| 27 | + feeder = DataFeeder(self.__data_types__, reader_dict) |
| 28 | + self.__gradient_machine__.start() |
| 29 | + for data_batch in reader(): |
| 30 | + yield self.__gradient_machine__.forwardTest(feeder(data_batch)) |
| 31 | + self.__gradient_machine__.finish() |
| 32 | + |
| 33 | + def iter_infer_field(self, field, **kwargs): |
| 34 | + for result in self.iter_infer(**kwargs): |
| 35 | + yield [each_result[field] for each_result in result] |
| 36 | + |
| 37 | + def infer(self, field='value', **kwargs): |
| 38 | + retv = None |
| 39 | + for result in self.iter_infer_field(field=field, **kwargs): |
| 40 | + if retv is None: |
| 41 | + retv = [[]] * len(result) |
| 42 | + for i, item in enumerate(result): |
| 43 | + retv[i].append(item) |
| 44 | + retv = [numpy.concatenate(out) for out in retv] |
| 45 | + if len(retv) == 1: |
| 46 | + return retv[0] |
| 47 | + else: |
| 48 | + return retv |
| 49 | + |
| 50 | + def default_reader_dict(self): |
| 51 | + reader_dict = dict() |
| 52 | + for i, tp in enumerate(self.__data_types__): |
| 53 | + reader_dict[tp[0]] = i |
| 54 | + return reader_dict |
| 55 | + |
| 56 | + |
| 57 | +def infer(output, parameters, reader, reader_dict=None, field='value'): |
| 58 | + inferer = Inference(output=output, parameters=parameters) |
| 59 | + return inferer.infer(field=field, reader=reader, reader_dict=reader_dict) |
0 commit comments