Skip to content

Commit 5246f5d

Browse files
authored
Speedup reader check_input_array when item is array (#25395) (#25848)
1 parent 45fa686 commit 5246f5d

File tree

1 file changed

+13
-25
lines changed

1 file changed

+13
-25
lines changed

python/paddle/fluid/reader.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,18 @@ def __iter__(self):
9696
def __next__(self):
9797
raise NotImplementedError()
9898

99+
@classmethod
def _check_input_array(cls, item):
    """Convert ``item`` to a regular ndarray and reject ragged input.

    ``np.asarray`` is used rather than ``np.array`` so that an ``item``
    which is already an ndarray is returned as-is instead of being copied.

    Args:
        item: The data to convert (for example a nested list or an
            ndarray).

    Returns:
        The ndarray produced by ``np.asarray(item)``.

    Raises:
        TypeError: If ``item`` cannot be represented as a regular
            (non-object) ndarray, which usually means it contains nested
            lists with different lengths.
    """
    message = (
        "\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
        "this means the input data contains nested lists with different lengths. "
        "\n\t* Check the reader function passed to 'decorate_batch_generator'"
        " to locate the data causes this issue.\n\t* Please consider using "
        "'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
    try:
        arr = np.asarray(item)
    except ValueError:
        # Newer NumPy releases refuse to build an array from nested
        # sequences of different lengths and raise ValueError instead of
        # silently producing an object array; report it the same way.
        raise TypeError(message)
    # ``np.object`` was a deprecated alias of the builtin ``object`` and
    # has been removed from NumPy; ``np.object_`` is the portable spelling.
    if arr.dtype == np.object_:
        raise TypeError(message)
    return arr
110+
99111

100112
class DataLoader(object):
101113
"""
@@ -806,17 +818,6 @@ def __next__(self):
806818
self._reset()
807819
six.reraise(*sys.exc_info())
808820

809-
@classmethod
810-
def _check_input_array(cls, item):
811-
arr = np.array(item)
812-
if arr.dtype == np.object:
813-
raise TypeError(
814-
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
815-
"this means the input data contains nested lists with different lengths. "
816-
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
817-
" to locate the data causes this issue.\n\t* Please consider using "
818-
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
819-
820821
def _exit_thread_expectedly(self):
821822
self._thread_done_event.set()
822823
self._blocking_queue.close()
@@ -893,7 +894,7 @@ def _reader_thread_loop_for_singleprocess(self):
893894
array = core.LoDTensorArray()
894895
for item in sample:
895896
if not isinstance(item, core.LoDTensor):
896-
self._check_input_array(item)
897+
item = self._check_input_array(item)
897898
tmp = core.LoDTensor()
898899
tmp.set(item, core.CPUPlace())
899900
item = tmp
@@ -1114,19 +1115,6 @@ def reset(self):
11141115
assert not self._iterable, "reset() cannot be called when DataLoader is iterable"
11151116
self._reset()
11161117

1117-
@classmethod
1118-
def _check_input_array(cls, item):
1119-
arr = np.array(item)
1120-
if arr.dtype == np.object:
1121-
raise TypeError((
1122-
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
1123-
"this means the input data contains nested lists with different lengths. "
1124-
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
1125-
" to locate the data causes this issue.\n\t* Please consider using "
1126-
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
1127-
1128-
return arr
1129-
11301118
def _start(self):
11311119
def __thread_main__():
11321120
try:

0 commit comments

Comments
 (0)