Skip to content

Commit aa2bcf5

Browse files
authored
Merge pull request #1537 from helinwang/batch
move paddle.reader.batch to paddle.batch
2 parents 8bef3f4 + 3432b4c commit aa2bcf5

File tree

5 files changed

+42
-27
lines changed

5 files changed

+42
-27
lines changed

demo/image_classification/api_v2_train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def event_handler(event):
6666
sys.stdout.flush()
6767
if isinstance(event, paddle.event.EndPass):
6868
result = trainer.test(
69-
reader=paddle.reader.batched(
69+
reader=paddle.batch(
7070
paddle.dataset.cifar.test10(), batch_size=128),
7171
reader_dict={'image': 0,
7272
'label': 1})
@@ -77,7 +77,7 @@ def event_handler(event):
7777
parameters=parameters,
7878
update_equation=momentum_optimizer)
7979
trainer.train(
80-
reader=paddle.reader.batched(
80+
reader=paddle.batch(
8181
paddle.reader.shuffle(
8282
paddle.dataset.cifar.train10(), buf_size=50000),
8383
batch_size=128),

demo/mnist/api_train_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def event_handler(event):
9898
result.metrics['classification_error_evaluator']))
9999

100100
trainer.train(
101-
reader=paddle.reader.batched(
101+
reader=paddle.batch(
102102
paddle.reader.shuffle(
103103
paddle.dataset.mnist.train(), buf_size=8192),
104104
batch_size=128),
@@ -115,7 +115,7 @@ def event_handler(event):
115115
probs = paddle.infer(
116116
output=predict,
117117
parameters=parameters,
118-
reader=paddle.reader.batched(
118+
reader=paddle.batch(
119119
paddle.reader.firstn(
120120
paddle.reader.map_readers(lambda item: (item[0], ),
121121
paddle.dataset.mnist.test()),

python/paddle/v2/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import inference
2929
import networks
3030
import py_paddle.swig_paddle as api
31+
import minibatch
3132

3233
__all__ = [
3334
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
@@ -45,3 +46,4 @@ def init(**kwargs):
4546

4647

4748
infer = inference.infer
49+
batch = minibatch.batch

python/paddle/v2/minibatch.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def batch(reader, batch_size):
17+
"""
18+
Create a batch reader.
19+
:param reader: the data reader to read from.
20+
:param batch_size: batch_size
21+
:return: the batch reader.
22+
"""
23+
24+
def batch_reader():
25+
r = reader()
26+
batch = []
27+
for instance in r:
28+
batch.append(instance)
29+
if len(batch) == batch_size:
30+
yield batch
31+
batch = []
32+
if batch:
33+
yield batch
34+
35+
return batch_reader

python/paddle/v2/reader/decorator.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
__all__ = [
1616
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
17-
'ComposeNotAligned', 'batched', 'firstn'
17+
'ComposeNotAligned', 'firstn'
1818
]
1919

2020
import itertools
@@ -193,28 +193,6 @@ def data_reader():
193193
return data_reader
194194

195195

196-
def batched(reader, batch_size):
197-
"""
198-
Create a batched reader.
199-
:param reader: the data reader to read from.
200-
:param batch_size: batch_size
201-
:return: the batched reader.
202-
"""
203-
204-
def batched_reader():
205-
r = reader()
206-
batch = []
207-
for instance in r:
208-
batch.append(instance)
209-
if len(batch) == batch_size:
210-
yield batch
211-
batch = []
212-
if batch:
213-
yield batch
214-
215-
return batched_reader
216-
217-
218196
def firstn(reader, n):
219197
"""
220198
Limit the max number of samples that reader could return.

0 commit comments

Comments
 (0)