Skip to content

Commit b993585

Browse files
authored
Merge pull request #711 from Haichao-Zhang/input_types_check
adding input type check for python data provider
2 parents 0948ea3 + 8e9ac0c commit b993585

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

python/paddle/trainer/PyDataProvider2.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,24 @@ def loop_check(callback, item):
202202
for each in item:
203203
callback(each)
204204

205+
class CheckInputTypeWrapper(object):
206+
def __init__(self, generator, input_types, logger):
207+
self.generator = generator
208+
self.input_types = input_types
209+
self.logger = logger
210+
211+
def __call__(self, obj, filename):
212+
for items in self.generator(obj, filename):
213+
try:
214+
# dict type is required for input_types when item is dict type
215+
assert (isinstance(items, dict) and \
216+
not isinstance(self.input_types, dict))==False
217+
yield items
218+
except AssertionError as e:
219+
self.logger.error(
220+
"%s type is required for input type but got %s" %
221+
(repr(type(items)), repr(type(self.input_types))))
222+
raise
205223

206224
def provider(input_types=None,
207225
should_shuffle=None,
@@ -355,6 +373,9 @@ def __init__(self, file_list, **kwargs):
355373
if use_dynamic_order:
356374
self.generator = InputOrderWrapper(self.generator,
357375
self.input_order)
376+
else:
377+
self.generator = CheckInputTypeWrapper(self.generator, self.slots,
378+
self.logger)
358379
if self.check:
359380
self.generator = CheckWrapper(self.generator, self.slots,
360381
check_fail_continue,

0 commit comments

Comments
 (0)