Commit ca62c10

Merge pull request #1564 from reyoung/feature/rename_reader_dict_to_feeding
Feature/rename reader dict to feeding
2 parents: 963bd5d + 2644536

9 files changed, +64 -76 lines changed
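
In user code, the rename is mechanical: the reader_dict keyword argument of trainer.train, trainer.test, and paddle.infer becomes feeding, and paddle.reader.batched becomes paddle.batch. A minimal before/after sketch of a caller; the reader and the input names 'image' and 'label' are illustrative, not taken from this commit:

    # before: trainer.train(reader=paddle.reader.batched(train_reader, batch_size=32),
    #                       reader_dict={'image': 0, 'label': 1}, num_passes=5)
    # after:
    trainer.train(
        reader=paddle.batch(train_reader, batch_size=32),
        feeding={'image': 0, 'label': 1},
        num_passes=5)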

demo/image_classification/api_v2_train.py

Lines changed: 7 additions & 6 deletions
@@ -13,8 +13,9 @@
 # limitations under the License

 import sys
+
 import paddle.v2 as paddle
-from api_v2_vgg import vgg_bn_drop
+
 from api_v2_resnet import resnet_cifar10


@@ -23,7 +24,7 @@ def main():
     classdim = 10

     # PaddlePaddle init
-    paddle.init(use_gpu=True, trainer_count=1)
+    paddle.init(use_gpu=False, trainer_count=1)

     image = paddle.layer.data(
         name="image", type=paddle.data_type.dense_vector(datadim))
@@ -68,8 +69,8 @@ def event_handler(event):
             result = trainer.test(
                 reader=paddle.batch(
                     paddle.dataset.cifar.test10(), batch_size=128),
-                reader_dict={'image': 0,
-                             'label': 1})
+                feeding={'image': 0,
+                         'label': 1})
             print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)

     # Create trainer
@@ -83,8 +84,8 @@ def event_handler(event):
             batch_size=128),
         num_passes=5,
         event_handler=event_handler,
-        reader_dict={'image': 0,
-                     'label': 1})
+        feeding={'image': 0,
+                 'label': 1})


 if __name__ == '__main__':

demo/introduction/api_train_v2.py

Lines changed: 11 additions & 11 deletions
@@ -30,26 +30,26 @@ def main():
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
             if event.batch_id % 100 == 0:
-                print "Pass %d, Batch %d, Cost %f, %s" % (
-                    event.pass_id, event.batch_id, event.cost, event.metrics)
+                print "Pass %d, Batch %d, Cost %f" % (
+                    event.pass_id, event.batch_id, event.cost)

         if isinstance(event, paddle.event.EndPass):
-            result = trainer.test(
-                reader=paddle.reader.batched(
-                    uci_housing.test(), batch_size=2),
-                reader_dict={'x': 0,
+            if (event.pass_id + 1) % 10 == 0:
+                result = trainer.test(
+                    reader=paddle.batch(
+                        uci_housing.test(), batch_size=2),
+                    feeding={'x': 0,
                              'y': 1})
-            if event.pass_id % 10 == 0:
-                print "Test %d, %s" % (event.pass_id, result.metrics)
+                print "Test %d, %.2f" % (event.pass_id, result.cost)

     # training
     trainer.train(
-        reader=paddle.reader.batched(
+        reader=paddle.batch(
             paddle.reader.shuffle(
                 uci_housing.train(), buf_size=500),
             batch_size=2),
-        reader_dict={'x': 0,
-                     'y': 1},
+        feeding={'x': 0,
+                 'y': 1},
         event_handler=event_handler,
         num_passes=30)

demo/mnist/api_train_v2.py

Lines changed: 2 additions & 2 deletions
@@ -92,7 +92,7 @@ def main():
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
             if event.batch_id % 1000 == 0:
-                result = trainer.test(reader=paddle.reader.batched(
+                result = trainer.test(reader=paddle.batch(
                     paddle.dataset.mnist.test(), batch_size=256))

                 print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
@@ -103,7 +103,7 @@ def event_handler(event):
                     parameters.to_tar(f)

         elif isinstance(event, paddle.event.EndPass):
-            result = trainer.test(reader=paddle.reader.batched(
+            result = trainer.test(reader=paddle.batch(
                 paddle.dataset.mnist.test(), batch_size=128))
             print "Test with Pass %d, Cost %f, %s\n" % (
                 event.pass_id, result.cost, result.metrics)

demo/semantic_role_labeling/api_train_v2.py

Lines changed: 3 additions & 3 deletions
@@ -163,11 +163,11 @@ def event_handler(event):
         update_equation=optimizer)
     parameters.set('emb', load_parameter(conll05.get_embedding(), 44068, 32))

-    trn_reader = paddle.reader.batched(
+    trn_reader = paddle.batch(
         paddle.reader.shuffle(
             conll05.test(), buf_size=8192), batch_size=10)

-    reader_dict = {
+    feeding = {
         'word_data': 0,
         'ctx_n2_data': 1,
         'ctx_n1_data': 2,
@@ -183,7 +183,7 @@ def event_handler(event):
         reader=trn_reader,
         event_handler=event_handler,
         num_passes=10000,
-        reader_dict=reader_dict)
+        feeding=feeding)


 if __name__ == '__main__':

demo/sentiment/train_v2.py

Lines changed: 9 additions & 14 deletions
@@ -18,11 +18,7 @@
 import paddle.v2 as paddle


-def convolution_net(input_dim,
-                    class_dim=2,
-                    emb_dim=128,
-                    hid_dim=128,
-                    is_predict=False):
+def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128):
     data = paddle.layer.data("word",
                              paddle.data_type.integer_value_sequence(input_dim))
     emb = paddle.layer.embedding(input=data, size=emb_dim)
@@ -42,8 +38,7 @@ def stacked_lstm_net(input_dim,
                      class_dim=2,
                      emb_dim=128,
                      hid_dim=512,
-                     stacked_num=3,
-                     is_predict=False):
+                     stacked_num=3):
     """
     A Wrapper for sentiment classification task.
     This network uses bi-directional recurrent network,
@@ -110,7 +105,7 @@ def stacked_lstm_net(input_dim,

 if __name__ == '__main__':
     # init
-    paddle.init(use_gpu=True, trainer_count=4)
+    paddle.init(use_gpu=False, trainer_count=4)

     # network config
     print 'load dictionary...'
@@ -143,11 +138,11 @@ def event_handler(event):
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(
-                reader=paddle.reader.batched(
+                reader=paddle.batch(
                    lambda: paddle.dataset.imdb.test(word_dict),
                    batch_size=128),
-                reader_dict={'word': 0,
-                             'label': 1})
+                feeding={'word': 0,
+                         'label': 1})
            print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)

     # create trainer
@@ -156,11 +151,11 @@ def event_handler(event):
         update_equation=adam_optimizer)

     trainer.train(
-        reader=paddle.reader.batched(
+        reader=paddle.batch(
             paddle.reader.shuffle(
                 lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
             batch_size=100),
         event_handler=event_handler,
-        reader_dict={'word': 0,
-                     'label': 1},
+        feeding={'word': 0,
+                 'label': 1},
         num_passes=10)

demo/seqToseq/api_train_v2.py

Lines changed: 3 additions & 3 deletions
@@ -80,13 +80,13 @@ def main():
         update_equation=optimizer)

     # define data reader
-    reader_dict = {
+    feeding = {
         'source_language_word': 0,
         'target_language_word': 1,
         'target_language_next_word': 2
     }

-    wmt14_reader = paddle.reader.batched(
+    wmt14_reader = paddle.batch(
         paddle.reader.shuffle(
             train_reader("data/pre-wmt14/train/train"), buf_size=8192),
         batch_size=5)
@@ -103,7 +103,7 @@ def event_handler(event):
         reader=wmt14_reader,
         event_handler=event_handler,
         num_passes=10000,
-        reader_dict=reader_dict)
+        feeding=feeding)


 if __name__ == '__main__':

python/paddle/v2/data_feeder.py

Lines changed: 18 additions & 6 deletions
@@ -14,11 +14,18 @@

 from py_paddle import DataProviderConverter

-import data_type
+import paddle.trainer.PyDataProvider2 as pydp2

 __all__ = ['DataFeeder']


+def default_feeding_map(data_types):
+    reader_dict = dict()
+    for i, tp in enumerate(data_types):
+        reader_dict[tp[0]] = i
+    return reader_dict
+
+
 class DataFeeder(DataProviderConverter):
     """
     DataFeeder converts the data returned by paddle.reader into a data structure
@@ -60,16 +67,21 @@ class DataFeeder(DataProviderConverter):
     :type data_types: list
     :param reader_dict: A dictionary to specify the position of each data
                         in the input data.
-    :type reader_dict: dict
+    :type feeding: dict
     """

-    def __init__(self, data_types, reader_dict):
+    def __init__(self, data_types, feeding=None):
         self.input_names = []
         input_types = []
-        self.reader_dict = reader_dict
+        if feeding is None:
+            feeding = default_feeding_map(data_types)
+
+        self.feeding = feeding
         for each in data_types:
             self.input_names.append(each[0])
-            assert isinstance(each[1], data_type.InputType)
+            if not isinstance(each[1], pydp2.InputType):
+                raise TypeError("second item in each data_type should be an "
+                                "InputType")
             input_types.append(each[1])
         DataProviderConverter.__init__(self, input_types)

@@ -90,7 +102,7 @@ def reorder_data(data):
         for each in data:
             reorder = []
             for name in self.input_names:
-                reorder.append(each[self.reader_dict[name]])
+                reorder.append(each[self.feeding[name]])
             retv.append(reorder)
         return retv
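
With the constructor above, DataFeeder now builds the name-to-column map itself when feeding is omitted (default_feeding_map maps the i-th declared input to column i). A short sketch of both forms; the data_types declarations are illustrative, only the argument handling comes from this diff:

    import paddle.v2 as paddle
    from paddle.v2.data_feeder import DataFeeder

    # (name, type) pairs; their order defines the default column mapping.
    data_types = [('image', paddle.data_type.dense_vector(784)),
                  ('label', paddle.data_type.integer_value(10))]

    # feeding omitted: equivalent to feeding={'image': 0, 'label': 1}.
    feeder = DataFeeder(data_types)

    # feeding given explicitly: the reader yields (label, image) tuples instead.
    feeder = DataFeeder(data_types, feeding={'image': 1, 'label': 0})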

python/paddle/v2/inference.py

Lines changed: 4 additions & 12 deletions
@@ -21,10 +21,8 @@ def __init__(self, output, parameters):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()

-    def iter_infer(self, reader, reader_dict=None):
-        if reader_dict is None:
-            reader_dict = self.default_reader_dict()
-        feeder = DataFeeder(self.__data_types__, reader_dict)
+    def iter_infer(self, reader, feeding=None):
+        feeder = DataFeeder(self.__data_types__, feeding)
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
@@ -47,13 +45,7 @@ def infer(self, field='value', **kwargs):
         else:
             return retv

-    def default_reader_dict(self):
-        reader_dict = dict()
-        for i, tp in enumerate(self.__data_types__):
-            reader_dict[tp[0]] = i
-        return reader_dict

-
-def infer(output, parameters, reader, reader_dict=None, field='value'):
+def infer(output, parameters, reader, feeding=None, field='value'):
     inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
+    return inferer.infer(field=field, reader=reader, feeding=feeding)
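
Since DataFeeder now supplies the default map, Inference no longer needs default_reader_dict and paddle.infer simply forwards feeding (or None). A hedged usage sketch; out_layer, params, and test_reader stand for an already-configured network and reader:

    # feeding omitted: reader columns follow the declaration order of the data layers.
    probs = paddle.infer(output=out_layer, parameters=params, reader=test_reader)

    # feeding given: the reader keeps the 'image' field in column 1, so say so explicitly.
    probs = paddle.infer(
        output=out_layer, parameters=params, reader=test_reader, feeding={'image': 1})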

python/paddle/v2/trainer.py

Lines changed: 7 additions & 19 deletions
@@ -61,7 +61,7 @@ def __init__(self, cost, parameters, update_equation):
         self.__gradient_machine__.randParameters()
         parameters.append_gradient_machine(gm)

-    def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
+    def train(self, reader, num_passes=1, event_handler=None, feeding=None):
         """
         Training method. Will train num_passes of input data.

@@ -70,14 +70,13 @@ def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
         :param event_handler: Event handler. A method will be invoked when event
                               occurred.
         :type event_handler: (BaseEvent) => None
+        :param feeding: Feeding is a map of neural network input name and array
+                        index that reader returns.
+        :type feeding: dict
         :return:
         """
         if event_handler is None:
             event_handler = default_event_handler
-
-        if reader_dict is None:
-            reader_dict = self.default_reader_dict()
-
         __check_train_args__(**locals())

         updater = self.__optimizer__.create_local_updater()
@@ -89,9 +88,7 @@ def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
         pass_evaluator = self.__gradient_machine__.makeEvaluator()
         assert isinstance(pass_evaluator, api.Evaluator)
         out_args = api.Arguments.createArguments(0)
-
-        feeder = DataFeeder(self.__data_types__, reader_dict)
-
+        feeder = DataFeeder(self.__data_types__, feeding)
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
@@ -125,17 +122,8 @@ def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
         self.__gradient_machine__.finish()

-    def default_reader_dict(self):
-        reader_dict = dict()
-        for i, tp in enumerate(self.__data_types__):
-            reader_dict[tp[0]] = i
-        return reader_dict
-
-    def test(self, reader, reader_dict=None):
-        if reader_dict is None:
-            reader_dict = self.default_reader_dict()
-
-        feeder = DataFeeder(self.__data_types__, reader_dict)
+    def test(self, reader, feeding=None):
+        feeder = DataFeeder(self.__data_types__, feeding)
         evaluator = self.__gradient_machine__.makeEvaluator()
         out_args = api.Arguments.createArguments(0)
         evaluator.start()
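
After this change, feeding is optional on both train and test; when it is None, DataFeeder falls back to the declaration order of the data layers. A sketch of the two call styles, with train_reader, test_reader, and event_handler as placeholders:

    # Explicit mapping, as in the updated demos: reader rows are (x, y) tuples.
    trainer.train(
        reader=paddle.batch(train_reader, batch_size=32),
        feeding={'x': 0, 'y': 1},
        event_handler=event_handler,
        num_passes=10)

    # Mapping omitted: DataFeeder's default_feeding_map assigns column i to the
    # i-th declared data layer.
    result = trainer.test(reader=paddle.batch(test_reader, batch_size=32))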
