Skip to content

Commit b265cca

Browse files
authored
Merge pull request #1464 from reyoung/feature/clean_mnist_v2
Combine Reader/Feeder together in trainer.train
2 parents ce32599 + eee1320 commit b265cca

File tree

17 files changed

+124
-109
lines changed

17 files changed

+124
-109
lines changed

demo/mnist/api_train_v2.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,5 @@
11
import paddle.v2 as paddle
22

3-
import mnist_util
4-
5-
6-
def train_reader():
7-
train_file = './data/raw_data/train'
8-
generator = mnist_util.read_from_mnist(train_file)
9-
for item in generator:
10-
yield item
11-
123

134
def main():
145
paddle.init(use_gpu=False, trainer_count=1)
@@ -40,11 +31,13 @@ def event_handler(event):
4031
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
4132

4233
trainer.train(
43-
train_data_reader=train_reader,
34+
reader=paddle.reader.batched(
35+
paddle.reader.shuffle(
36+
paddle.dataset.mnist.train(), buf_size=8192),
37+
batch_size=32),
4438
cost=cost,
4539
parameters=parameters,
4640
event_handler=event_handler,
47-
batch_size=32, # batch size should be refactor in Data reader
4841
reader_dict={images.name: 0,
4942
label.name: 1})
5043

python/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ set(OUTPUT_DIR
44
file(GLOB TRAINER_PY_FILES . ./paddle/trainer/*.py)
55
file(GLOB HELPERS_PY_FILES . ./paddle/trainer_config_helpers/*.py)
66
file(GLOB UTILS_PY_FILES . ./paddle/utils/*.py)
7-
file(GLOB V2_PY_FILES . ./paddle/v2/*.py)
7+
file(GLOB_RECURSE V2_PY_FILES ./paddle/v2/ *.py)
88

99
set(PY_FILES paddle/__init__.py
1010
${TRAINER_PY_FILES}
@@ -24,7 +24,7 @@ add_custom_target(paddle_python ALL DEPENDS
2424
${OUTPUT_DIR}/.timestamp)
2525

2626
add_subdirectory(paddle/trainer_config_helpers/tests)
27-
add_subdirectory(paddle/reader/tests)
27+
add_subdirectory(paddle/v2/reader/tests)
2828
add_subdirectory(paddle/v2/tests)
2929

3030
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/dist/

python/paddle/reader/tests/CMakeLists.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.

python/paddle/v2/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
import data_type
2121
import topology
2222
import data_feeder
23+
from . import dataset
24+
from . import reader
2325
import attr
2426
import pooling
2527
import py_paddle.swig_paddle as api
2628

2729
__all__ = [
2830
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
29-
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'topology'
31+
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
32+
'topology'
3033
]
3134

3235

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import mnist
2+
3+
__all__ = ['mnist']

python/paddle/v2/dataset/mnist.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
MNIST dataset.
33
"""
4-
import numpy
54
import paddle.v2.dataset.common
65
import subprocess
7-
6+
import numpy
7+
import platform
88
__all__ = ['train', 'test']
99

1010
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
@@ -20,12 +20,19 @@
2020

2121
def reader_creator(image_filename, label_filename, buffer_size):
2222
def reader():
23+
if platform.system() == 'Darwin':
24+
zcat_cmd = 'gzcat'
25+
elif platform.system() == 'Linux':
26+
zcat_cmd = 'zcat'
27+
else:
28+
raise NotImplementedError()
29+
2330
# According to http://stackoverflow.com/a/38061619/724872, we
2431
# cannot use standard package gzip here.
25-
m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE)
32+
m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE)
2633
m.stdout.read(16) # skip some magic bytes
2734

28-
l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE)
35+
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
2936
l.stdout.read(8) # skip some magic bytes
3037

3138
while True:
File renamed without changes.
File renamed without changes.

python/paddle/reader/decorator.py renamed to python/paddle/v2/reader/decorator.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
__all__ = [
1616
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
17-
'ComposeNotAligned'
17+
'ComposeNotAligned', 'batched'
1818
]
1919

2020
from Queue import Queue
@@ -191,3 +191,25 @@ def data_reader():
191191
e = q.get()
192192

193193
return data_reader
194+
195+
196+
def batched(reader, batch_size, drop_last=False):
    """
    Create a batched reader.

    Wraps a reader creator so that the instances it produces are grouped
    into lists of ``batch_size`` consecutive instances.

    :param reader: the data reader to read from; it is called with no
        arguments each time the batched reader is started, and the result
        is iterated over.
    :param batch_size: number of instances per batch. It is coerced with
        ``int()`` and must be positive.
    :param drop_last: when True, a trailing batch holding fewer than
        ``batch_size`` instances is discarded instead of being yielded.
        Defaults to False, which keeps the trailing partial batch.
    :return: the batched reader, a callable returning a generator of lists.
    :raises ValueError: if ``batch_size`` is not a positive integer.
    """
    # Validate eagerly: with a non-positive size the equality test below
    # could never succeed and the whole dataset would silently be
    # accumulated into a single batch.
    batch_size = int(batch_size)
    if batch_size <= 0:
        raise ValueError("batch_size should be a positive integer value, "
                         "but got batch_size={}".format(batch_size))

    def batched_reader():
        batch = []
        for instance in reader():
            batch.append(instance)
            if len(batch) == batch_size:
                yield batch
                # Start a fresh list so the batch already handed out is
                # not mutated by later appends.
                batch = []
        # Whatever is left over is a partial batch.
        if batch and not drop_last:
            yield batch

    return batched_reader
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
add_test(NAME reader_tests
2+
COMMAND bash ${PROJ_ROOT}/python/paddle/v2/reader/tests/run_tests.sh
3+
${PYTHON_EXECUTABLE})

0 commit comments

Comments
 (0)