@@ -21,30 +21,16 @@ def __init__(self, output_layer, parameters):
21
21
self .__gradient_machine__ = gm
22
22
self .__data_types__ = topo .data_type ()
23
23
24
- def iter_infer (self , input = None , batch_size = None , reader = None ,
25
- feeding = None ):
24
+ def iter_infer (self , input , feeding = None ):
26
25
feeder = DataFeeder (self .__data_types__ , feeding )
27
- if reader is None :
28
- assert input is not None and isinstance (input , collections .Iterable )
29
- if not isinstance (input , collections .Iterable ):
30
- raise TypeError ("When reader is None, input should be whole "
31
- "inference data and should be iterable" )
32
-
33
- if batch_size is None :
34
- if not hasattr (input , '__len__' ):
35
- raise ValueError ("Should set batch size when input data "
36
- "don't contain length." )
37
- batch_size = len (input )
38
-
39
- def __reader_impl__ ():
40
- for each_sample in input :
41
- yield each_sample
42
-
43
- reader = minibatch .batch (__reader_impl__ , batch_size = batch_size )
44
- else :
45
- if input is not None :
46
- raise ValueError ("User should set either input or reader, "
47
- "should not set them both." )
26
+ batch_size = len (input )
27
+
28
+ def __reader_impl__ ():
29
+ for each_sample in input :
30
+ yield each_sample
31
+
32
+ reader = minibatch .batch (__reader_impl__ , batch_size = batch_size )
33
+
48
34
self .__gradient_machine__ .start ()
49
35
for data_batch in reader ():
50
36
yield self .__gradient_machine__ .forwardTest (feeder (data_batch ))
0 commit comments