Skip to content

Commit ac843bb

Browse files
authored
Update with comments
1 parent a503f3c commit ac843bb

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

python/paddle/v2/inference.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,26 @@ def __reader_impl__():
3838

3939
def iter_infer_field(self, field, **kwargs):
4040
for result in self.iter_infer(**kwargs):
41-
yield [each_result[field] for each_result in result]
41+
yield [
42+
each_result[each_field]
43+
for each_result in result for each_field in field
44+
]
4245

4346
def infer(self, field='value', **kwargs):
4447
if not isinstance(field, list) and not isinstance(field, tuple):
4548
field = [field]
4649

47-
retv_list = []
48-
for each_field in field:
49-
retv = None
50-
for result in self.iter_infer_field(field=each_field, **kwargs):
51-
if retv is None:
52-
retv = [[]] * len(result)
53-
for i, item in enumerate(result):
54-
retv[i].append(item)
55-
retv = [numpy.concatenate(out) for out in retv]
56-
retv_list.append(retv[0] if len(retv) == 1 else retv)
57-
return retv_list[0] if len(retv_list) == 1 else retv_list
50+
retv = None
51+
for result in self.iter_infer_field(field=field, **kwargs):
52+
if retv is None:
53+
retv = [[]] * len(result)
54+
for i, item in enumerate(result):
55+
retv[i].append(item)
56+
retv = [numpy.concatenate(out) for out in retv]
57+
if len(retv) == 1:
58+
return retv[0]
59+
else:
60+
return retv
5861

5962

6063
def infer(output_layer, parameters, input, feeding=None, field='value'):

0 commit comments

Comments
 (0)