@@ -96,6 +96,18 @@ def __iter__(self):
96
96
def __next__ (self ):
97
97
raise NotImplementedError ()
98
98
99
+ @classmethod
100
+ def _check_input_array (cls , item ):
101
+ arr = np .asarray (item )
102
+ if arr .dtype == np .object :
103
+ raise TypeError (
104
+ "\n \t Faild to convert input data to a regular ndarray :\n \t * Usually "
105
+ "this means the input data contains nested lists with different lengths. "
106
+ "\n \t * Check the reader function passed to 'decorate_batch_generator'"
107
+ " to locate the data causes this issue.\n \t * Please consider using "
108
+ "'fluid.create_lod_tensor' to convert it to a LoD-Tensor." )
109
+ return arr
110
+
99
111
100
112
class DataLoader (object ):
101
113
"""
@@ -806,17 +818,6 @@ def __next__(self):
806
818
self ._reset ()
807
819
six .reraise (* sys .exc_info ())
808
820
809
- @classmethod
810
- def _check_input_array (cls , item ):
811
- arr = np .array (item )
812
- if arr .dtype == np .object :
813
- raise TypeError (
814
- "\n \t Faild to convert input data to a regular ndarray :\n \t * Usually "
815
- "this means the input data contains nested lists with different lengths. "
816
- "\n \t * Check the reader function passed to 'decorate_batch_generator'"
817
- " to locate the data causes this issue.\n \t * Please consider using "
818
- "'fluid.create_lod_tensor' to convert it to a LoD-Tensor." )
819
-
820
821
def _exit_thread_expectedly (self ):
821
822
self ._thread_done_event .set ()
822
823
self ._blocking_queue .close ()
@@ -893,7 +894,7 @@ def _reader_thread_loop_for_singleprocess(self):
893
894
array = core .LoDTensorArray ()
894
895
for item in sample :
895
896
if not isinstance (item , core .LoDTensor ):
896
- self ._check_input_array (item )
897
+ item = self ._check_input_array (item )
897
898
tmp = core .LoDTensor ()
898
899
tmp .set (item , core .CPUPlace ())
899
900
item = tmp
@@ -1114,19 +1115,6 @@ def reset(self):
1114
1115
assert not self ._iterable , "reset() cannot be called when DataLoader is iterable"
1115
1116
self ._reset ()
1116
1117
1117
- @classmethod
1118
- def _check_input_array (cls , item ):
1119
- arr = np .array (item )
1120
- if arr .dtype == np .object :
1121
- raise TypeError ((
1122
- "\n \t Faild to convert input data to a regular ndarray :\n \t * Usually "
1123
- "this means the input data contains nested lists with different lengths. "
1124
- "\n \t * Check the reader function passed to 'decorate_batch_generator'"
1125
- " to locate the data causes this issue.\n \t * Please consider using "
1126
- "'fluid.create_lod_tensor' to convert it to a LoD-Tensor." ))
1127
-
1128
- return arr
1129
-
1130
1118
def _start (self ):
1131
1119
def __thread_main__ ():
1132
1120
try :
0 commit comments