Skip to content

Commit a503f3c

Browse files
authored
support multi-field for inference
1 parent caffcc8 commit a503f3c

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

python/paddle/v2/inference.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,20 @@ def iter_infer_field(self, field, **kwargs):
4141
yield [each_result[field] for each_result in result]
4242

4343
def infer(self, field='value', **kwargs):
44-
retv = None
45-
for result in self.iter_infer_field(field=field, **kwargs):
46-
if retv is None:
47-
retv = [[]] * len(result)
48-
for i, item in enumerate(result):
49-
retv[i].append(item)
50-
retv = [numpy.concatenate(out) for out in retv]
51-
if len(retv) == 1:
52-
return retv[0]
53-
else:
54-
return retv
44+
if not isinstance(field, list) and not isinstance(field, tuple):
45+
field = [field]
46+
47+
retv_list = []
48+
for each_field in field:
49+
retv = None
50+
for result in self.iter_infer_field(field=each_field, **kwargs):
51+
if retv is None:
52+
retv = [[]] * len(result)
53+
for i, item in enumerate(result):
54+
retv[i].append(item)
55+
retv = [numpy.concatenate(out) for out in retv]
56+
retv_list.append(retv[0] if len(retv) == 1 else retv)
57+
return retv_list[0] if len(retv_list) == 1 else retv_list
5558

5659

5760
def infer(output_layer, parameters, input, feeding=None, field='value'):

0 commit comments

Comments
 (0)