@@ -38,23 +38,26 @@ def __reader_impl__():
38
38
39
39
def iter_infer_field (self , field , ** kwargs ):
40
40
for result in self .iter_infer (** kwargs ):
41
- yield [each_result [field ] for each_result in result ]
41
+ yield [
42
+ each_result [each_field ]
43
+ for each_result in result for each_field in field
44
+ ]
42
45
43
46
def infer (self , field = 'value' , ** kwargs ):
44
47
if not isinstance (field , list ) and not isinstance (field , tuple ):
45
48
field = [field ]
46
49
47
- retv_list = []
48
- for each_field in field :
49
- retv = None
50
- for result in self . iter_infer_field ( field = each_field , ** kwargs ):
51
- if retv is None :
52
- retv = [[]] * len ( result )
53
- for i , item in enumerate ( result ):
54
- retv [ i ]. append ( item )
55
- retv = [ numpy . concatenate ( out ) for out in retv ]
56
- retv_list . append ( retv [ 0 ] if len ( retv ) == 1 else retv )
57
- return retv_list [ 0 ] if len ( retv_list ) == 1 else retv_list
50
+ retv = None
51
+ for result in self . iter_infer_field ( field = field , ** kwargs ) :
52
+ if retv is None :
53
+ retv = [[]] * len ( result )
54
+ for i , item in enumerate ( result ) :
55
+ retv [ i ]. append ( item )
56
+ retv = [ numpy . concatenate ( out ) for out in retv ]
57
+ if len ( retv ) == 1 :
58
+ return retv [ 0 ]
59
+ else :
60
+ return retv
58
61
59
62
60
63
def infer (output_layer , parameters , input , feeding = None , field = 'value' ):
0 commit comments