File tree Expand file tree Collapse file tree 2 files changed +8
-4
lines changed Expand file tree Collapse file tree 2 files changed +8
-4
lines changed Original file line number Diff line number Diff line change 15
15
__all__ = ['batch' ]
16
16
17
17
18
- def batch (reader , batch_size ):
18
+ def batch (reader , batch_size , drop_last = False ):
19
19
"""
20
20
Create a batched reader.
21
21
22
22
:param reader: the data reader to read from.
23
23
:type reader: callable
24
24
:param batch_size: size of each mini-batch
25
25
:type batch_size: int
26
+ :param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
27
+ :type drop_last: bool
26
28
:return: the batched reader.
27
29
:rtype: callable
28
30
"""
@@ -35,7 +37,7 @@ def batch_reader():
35
37
if len (b ) == batch_size :
36
38
yield b
37
39
b = []
38
- if b :
40
+ if drop_last == False and len ( b ) != 0 :
39
41
yield b
40
42
41
43
return batch_reader
Original file line number Diff line number Diff line change 15
15
__all__ = ['batch' ]
16
16
17
17
18
- def batch (reader , batch_size ):
18
+ def batch (reader , batch_size , drop_last = False ):
19
19
"""
20
20
Create a batched reader.
21
21
22
22
:param reader: the data reader to read from.
23
23
:type reader: callable
24
24
:param batch_size: size of each mini-batch
25
25
:type batch_size: int
26
+ :param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
27
+ :type drop_last: bool
26
28
:return: the batched reader.
27
29
:rtype: callable
28
30
"""
@@ -35,7 +37,7 @@ def batch_reader():
35
37
if len (b ) == batch_size :
36
38
yield b
37
39
b = []
38
- if b :
40
+ if drop_last == False and len ( b ) != 0 :
39
41
yield b
40
42
41
43
return batch_reader
You can’t perform that action at this time.
0 commit comments