@@ -202,6 +202,24 @@ def loop_check(callback, item):
202
202
for each in item :
203
203
callback (each )
204
204
205
+ class CheckInputTypeWrapper (object ):
206
+ def __init__ (self , generator , input_types , logger ):
207
+ self .generator = generator
208
+ self .input_types = input_types
209
+ self .logger = logger
210
+
211
+ def __call__ (self , obj , filename ):
212
+ for items in self .generator (obj , filename ):
213
+ try :
214
+ # dict type is required for input_types when item is dict type
215
+ assert (isinstance (items , dict ) and \
216
+ not isinstance (self .input_types , dict ))== False
217
+ yield items
218
+ except AssertionError as e :
219
+ self .logger .error (
220
+ "%s type is required for input type but got %s" %
221
+ (repr (type (items )), repr (type (self .input_types ))))
222
+ raise
205
223
206
224
def provider (input_types = None ,
207
225
should_shuffle = None ,
@@ -355,6 +373,9 @@ def __init__(self, file_list, **kwargs):
355
373
if use_dynamic_order :
356
374
self .generator = InputOrderWrapper (self .generator ,
357
375
self .input_order )
376
+ else :
377
+ self .generator = CheckInputTypeWrapper (self .generator , self .slots ,
378
+ self .logger )
358
379
if self .check :
359
380
self .generator = CheckWrapper (self .generator , self .slots ,
360
381
check_fail_continue ,
0 commit comments