Skip to content

Commit 6915d51

Browse files
author
chengduo
authored
Merge pull request #11062 from chengduoZH/refine_batch_py
Drop the last batch, if the size of last batch is not equal to batch_size.
2 parents 0c0c5df + 164692d commit 6915d51

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/paddle/batch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
__all__ = ['batch']
1616

1717

18-
def batch(reader, batch_size):
18+
def batch(reader, batch_size, drop_last=False):
1919
"""
2020
Create a batched reader.
2121
2222
:param reader: the data reader to read from.
2323
:type reader: callable
2424
:param batch_size: size of each mini-batch
2525
:type batch_size: int
26+
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
27+
:type drop_last: bool
2628
:return: the batched reader.
2729
:rtype: callable
2830
"""
@@ -35,7 +37,7 @@ def batch_reader():
3537
if len(b) == batch_size:
3638
yield b
3739
b = []
38-
if b:
40+
if drop_last == False and len(b) != 0:
3941
yield b
4042

4143
return batch_reader

python/paddle/v2/minibatch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
__all__ = ['batch']
1616

1717

18-
def batch(reader, batch_size):
18+
def batch(reader, batch_size, drop_last=False):
1919
"""
2020
Create a batched reader.
2121
2222
:param reader: the data reader to read from.
2323
:type reader: callable
2424
:param batch_size: size of each mini-batch
2525
:type batch_size: int
26+
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
27+
:type drop_last: bool
2628
:return: the batched reader.
2729
:rtype: callable
2830
"""
@@ -35,7 +37,7 @@ def batch_reader():
3537
if len(b) == batch_size:
3638
yield b
3739
b = []
38-
if b:
40+
if drop_last == False and len(b) != 0:
3941
yield b
4042

4143
return batch_reader

0 commit comments

Comments
 (0)