Commit 05b45e1

Remove reader logic
1 parent 797e89e commit 05b45e1

File tree

1 file changed: +9, -23 lines changed

python/paddle/v2/inference.py

Lines changed: 9 additions & 23 deletions
@@ -21,30 +21,16 @@ def __init__(self, output_layer, parameters):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
 
-    def iter_infer(self, input=None, batch_size=None, reader=None,
-                   feeding=None):
+    def iter_infer(self, input, feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
-        if reader is None:
-            assert input is not None and isinstance(input, collections.Iterable)
-            if not isinstance(input, collections.Iterable):
-                raise TypeError("When reader is None, input should be whole "
-                                "inference data and should be iterable")
-
-            if batch_size is None:
-                if not hasattr(input, '__len__'):
-                    raise ValueError("Should set batch size when input data "
-                                     "don't contain length.")
-                batch_size = len(input)
-
-            def __reader_impl__():
-                for each_sample in input:
-                    yield each_sample
-
-            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
-        else:
-            if input is not None:
-                raise ValueError("User should set either input or reader, "
-                                 "should not set them both.")
+        batch_size = len(input)
+
+        def __reader_impl__():
+            for each_sample in input:
+                yield each_sample
+
+        reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
+
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
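
For context, a minimal usage sketch of the simplified interface after this change. Only the __init__(self, output_layer, parameters) constructor and the new iter_infer(self, input, feeding=None) signature come from the diff itself; the class name Inference, the out_layer/params objects, and the feeding mapping below are assumptions for illustration. Note that because batch_size = len(input), the whole input list is forwarded as a single batch.

    # Hypothetical sketch -- class name and surrounding setup are assumed,
    # not shown in this diff.
    inferer = Inference(output_layer=out_layer, parameters=params)

    # 'input' is the whole inference data set; it must support len(),
    # and each element is one sample laid out as described by 'feeding'.
    samples = [(x,) for x in test_features]

    for batch_result in inferer.iter_infer(input=samples,
                                           feeding={'data': 0}):
        # Each iteration yields forwardTest() output for one minibatch;
        # here that is a single batch, since batch_size == len(samples).
        print(batch_result)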
