Skip to content

Commit 2a3a1e9

Browse files
authored
Add DataFeeder (#6102)
* Add DataFeeder A v2 API like data feeder for book demos. We can feed data directly from reader. * Fix CI * Remove batch_size_dim for feeder Also add __all__ to data_feeder.py * Follow comment
1 parent 09fc307 commit 2a3a1e9

11 files changed

+177
-128
lines changed

python/paddle/v2/fluid/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,21 @@
1414
import backward
1515
import regularizer
1616
from param_attr import ParamAttr
17-
17+
from data_feeder import DataFeeder
1818
from core import LoDTensor, CPUPlace, GPUPlace
1919

2020
Tensor = LoDTensor
2121
__all__ = framework.__all__ + executor.__all__ + [
2222
'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward',
2323
'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr'
24+
'DataFeeder'
2425
]
2526

2627

2728
def __read_gflags_from_env__():
2829
"""
2930
Enable reading gflags from environment variables.
30-
31+
3132
Returns:
3233
None
3334
"""

python/paddle/v2/fluid/data_feeder.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from __future__ import print_function
2+
3+
import core
4+
import numpy
5+
import six.moves as six
6+
7+
from framework import Variable
8+
9+
__all__ = ['DataFeeder']
10+
11+
12+
class DataToLoDTensorConverter(object):
13+
def __init__(self, place, lod_level, shape, dtype):
14+
self.place = place
15+
self.lod_level = lod_level
16+
self.shape = shape
17+
if dtype == core.DataType.FP32:
18+
self.dtype = 'float32'
19+
elif dtype == core.DataType.INT64:
20+
self.dtype = 'int64'
21+
elif dtype == core.DataType.FP64:
22+
self.dtype = 'float64'
23+
elif dtype == core.DataType.INT32:
24+
self.dtype = 'int32'
25+
else:
26+
raise ValueError("dtype must be any of [int32, float32, int64, "
27+
"float64]")
28+
29+
self.data = []
30+
self.lod = []
31+
32+
for i in six.range(lod_level):
33+
self.lod.append([0])
34+
35+
def feed(self, data):
36+
self._feed_impl_(data, self.lod, self.lod_level)
37+
38+
def _feed_impl_(self, data, lod, lod_level):
39+
if lod_level == 0:
40+
self.data.append(data)
41+
else:
42+
cur_lod_len = len(data)
43+
lod[-1].append(lod[-1][-1] + cur_lod_len)
44+
for each_data in data:
45+
self._feed_impl_(each_data, lod[:-1], lod_level - 1)
46+
47+
def done(self):
48+
arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape)
49+
t = core.LoDTensor()
50+
t.set(arr, self.place)
51+
if self.lod_level > 0:
52+
t.set_lod(self.lod)
53+
return t
54+
55+
56+
class DataFeeder(object):
57+
def __init__(self, feed_list, place):
58+
self.feed_dtypes = []
59+
self.feed_names = []
60+
self.feed_shapes = []
61+
self.feed_lod_level = []
62+
for each_var in feed_list:
63+
if not isinstance(each_var, Variable):
64+
raise TypeError("Feed list should contain a list of variable")
65+
self.feed_dtypes.append(each_var.dtype)
66+
self.feed_names.append(each_var.name)
67+
shape = each_var.shape
68+
batch_size_dim = -1
69+
for i, s in enumerate(shape):
70+
if s < 0:
71+
batch_size_dim = i
72+
break
73+
if batch_size_dim == -1:
74+
raise ValueError("Variable {0} must has a batch size dimension",
75+
each_var.name)
76+
self.feed_lod_level.append(each_var.lod_level)
77+
self.feed_shapes.append(shape)
78+
79+
self.place = place
80+
81+
def feed(self, iterable):
82+
converter = []
83+
for lod_level, shape, dtype in six.zip(
84+
self.feed_lod_level, self.feed_shapes, self.feed_dtypes):
85+
converter.append(
86+
DataToLoDTensorConverter(
87+
place=self.place,
88+
lod_level=lod_level,
89+
shape=shape,
90+
dtype=dtype))
91+
92+
for each_sample in iterable:
93+
for each_converter, each_slot in six.zip(converter, each_sample):
94+
each_converter.feed(each_slot)
95+
ret_dict = {}
96+
for each_name, each_converter in six.zip(self.feed_names, converter):
97+
ret_dict[each_name] = each_converter.done()
98+
return ret_dict

python/paddle/v2/fluid/tests/book/test_fit_a_line.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
batch_size=BATCH_SIZE)
2323

2424
place = fluid.CPUPlace()
25+
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
2526
exe = fluid.Executor(place)
2627

2728
exe.run(fluid.default_startup_program())
@@ -31,12 +32,8 @@
3132
fluid.io.save_persistables(exe, "./fit_a_line.model/")
3233
fluid.io.load_persistables(exe, "./fit_a_line.model/")
3334
for data in train_reader():
34-
x_data = np.array(map(lambda _: _[0], data)).astype("float32")
35-
y_data = np.array(map(lambda _: _[1], data)).astype("float32")
36-
3735
avg_loss_value, = exe.run(fluid.default_main_program(),
38-
feed={'x': x_data,
39-
'y': y_data},
36+
feed=feeder.feed(data),
4037
fetch_list=[avg_cost])
4138

4239
if avg_loss_value[0] < 10.0:

python/paddle/v2/fluid/tests/book/test_image_classification_train.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,14 @@ def conv_block(input, num_filter, groups, dropouts):
113113

114114
place = fluid.CPUPlace()
115115
exe = fluid.Executor(place)
116-
116+
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
117117
exe.run(fluid.default_startup_program())
118118

119119
for pass_id in range(PASS_NUM):
120120
accuracy.reset(exe)
121121
for data in train_reader():
122-
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
123-
data)).astype("float32")
124-
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
125-
batch_size = 1
126-
for i in y_data.shape:
127-
batch_size = batch_size * i
128-
y_data = y_data.reshape([batch_size, 1])
129-
130122
loss, acc = exe.run(fluid.default_main_program(),
131-
feed={"pixel": img_data,
132-
"label": y_data},
123+
feed=feeder.feed(data),
133124
fetch_list=[avg_cost] + accuracy.metrics)
134125
pass_acc = accuracy.eval(exe)
135126
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(

python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,9 @@ def load_parameter(file_name, h, w):
2828
return np.fromfile(f, dtype=np.float32).reshape(h, w)
2929

3030

31-
def db_lstm():
31+
def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
32+
**ignored):
3233
# 8 features
33-
word = fluid.layers.data(name='word_data', shape=[1], dtype='int64')
34-
predicate = fluid.layers.data(name='verb_data', shape=[1], dtype='int64')
35-
ctx_n2 = fluid.layers.data(name='ctx_n2_data', shape=[1], dtype='int64')
36-
ctx_n1 = fluid.layers.data(name='ctx_n1_data', shape=[1], dtype='int64')
37-
ctx_0 = fluid.layers.data(name='ctx_0_data', shape=[1], dtype='int64')
38-
ctx_p1 = fluid.layers.data(name='ctx_p1_data', shape=[1], dtype='int64')
39-
ctx_p2 = fluid.layers.data(name='ctx_p2_data', shape=[1], dtype='int64')
40-
mark = fluid.layers.data(name='mark_data', shape=[1], dtype='int64')
41-
4234
predicate_embedding = fluid.layers.embedding(
4335
input=predicate,
4436
size=[pred_len, word_dim],
@@ -120,8 +112,25 @@ def to_lodtensor(data, place):
120112

121113
def main():
122114
# define network topology
123-
feature_out = db_lstm()
124-
target = fluid.layers.data(name='target', shape=[1], dtype='int64')
115+
word = fluid.layers.data(
116+
name='word_data', shape=[1], dtype='int64', lod_level=1)
117+
predicate = fluid.layers.data(
118+
name='verb_data', shape=[1], dtype='int64', lod_level=1)
119+
ctx_n2 = fluid.layers.data(
120+
name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1)
121+
ctx_n1 = fluid.layers.data(
122+
name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1)
123+
ctx_0 = fluid.layers.data(
124+
name='ctx_0_data', shape=[1], dtype='int64', lod_level=1)
125+
ctx_p1 = fluid.layers.data(
126+
name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1)
127+
ctx_p2 = fluid.layers.data(
128+
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
129+
mark = fluid.layers.data(
130+
name='mark_data', shape=[1], dtype='int64', lod_level=1)
131+
feature_out = db_lstm(**locals())
132+
target = fluid.layers.data(
133+
name='target', shape=[1], dtype='int64', lod_level=1)
125134
crf_cost = fluid.layers.linear_chain_crf(
126135
input=feature_out,
127136
label=target,
@@ -139,6 +148,11 @@ def main():
139148
paddle.dataset.conll05.test(), buf_size=8192),
140149
batch_size=BATCH_SIZE)
141150
place = fluid.CPUPlace()
151+
feeder = fluid.DataFeeder(
152+
feed_list=[
153+
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
154+
],
155+
place=place)
142156
exe = fluid.Executor(place)
143157

144158
exe.run(fluid.default_startup_program())
@@ -150,28 +164,8 @@ def main():
150164
batch_id = 0
151165
for pass_id in xrange(PASS_NUM):
152166
for data in train_data():
153-
word_data = to_lodtensor(map(lambda x: x[0], data), place)
154-
ctx_n2_data = to_lodtensor(map(lambda x: x[1], data), place)
155-
ctx_n1_data = to_lodtensor(map(lambda x: x[2], data), place)
156-
ctx_0_data = to_lodtensor(map(lambda x: x[3], data), place)
157-
ctx_p1_data = to_lodtensor(map(lambda x: x[4], data), place)
158-
ctx_p2_data = to_lodtensor(map(lambda x: x[5], data), place)
159-
verb_data = to_lodtensor(map(lambda x: x[6], data), place)
160-
mark_data = to_lodtensor(map(lambda x: x[7], data), place)
161-
target = to_lodtensor(map(lambda x: x[8], data), place)
162-
163167
outs = exe.run(fluid.default_main_program(),
164-
feed={
165-
'word_data': word_data,
166-
'ctx_n2_data': ctx_n2_data,
167-
'ctx_n1_data': ctx_n1_data,
168-
'ctx_0_data': ctx_0_data,
169-
'ctx_p1_data': ctx_p1_data,
170-
'ctx_p2_data': ctx_p2_data,
171-
'verb_data': verb_data,
172-
'mark_data': mark_data,
173-
'target': target
174-
},
168+
feed=feeder.feed(data),
175169
fetch_list=[avg_cost])
176170
avg_cost_val = np.array(outs[0])
177171

python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,14 @@
3737

3838
place = fluid.CPUPlace()
3939
exe = fluid.Executor(place)
40-
40+
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
4141
exe.run(fluid.default_startup_program())
4242

4343
for pass_id in range(PASS_NUM):
4444
accuracy.reset(exe)
4545
for data in train_reader():
46-
img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]),
47-
data)).astype("float32")
48-
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
49-
y_data = y_data.reshape([BATCH_SIZE, 1])
50-
5146
loss, acc = exe.run(fluid.default_main_program(),
52-
feed={"pixel": img_data,
53-
"label": y_data},
47+
feed=feeder.feed(data),
5448
fetch_list=[avg_cost] + accuracy.metrics)
5549
pass_acc = accuracy.eval(exe)
5650
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" +

python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,40 +48,22 @@
4848

4949
place = fluid.CPUPlace()
5050
exe = fluid.Executor(place)
51-
51+
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
5252
exe.run(fluid.default_startup_program())
5353

5454
PASS_NUM = 100
5555
for pass_id in range(PASS_NUM):
5656
accuracy.reset(exe)
5757
for data in train_reader():
58-
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
59-
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
60-
y_data = np.expand_dims(y_data, axis=1)
61-
62-
tensor_x = fluid.LoDTensor()
63-
tensor_x.set(x_data, place)
64-
65-
tensor_y = fluid.LoDTensor()
66-
tensor_y.set(y_data, place)
67-
68-
outs = exe.run(fluid.default_main_program(),
69-
feed={'x': tensor_x,
70-
'y': tensor_y},
71-
fetch_list=[avg_cost] + accuracy.metrics)
72-
out = np.array(outs[0])
73-
acc = np.array(outs[1])
58+
out, acc = exe.run(fluid.default_main_program(),
59+
feed=feeder.feed(data),
60+
fetch_list=[avg_cost] + accuracy.metrics)
7461
pass_acc = accuracy.eval(exe)
7562

7663
test_accuracy.reset(exe)
7764
for data in test_reader():
78-
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
79-
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
80-
y_data = np.expand_dims(y_data, axis=1)
81-
8265
out, acc = exe.run(inference_program,
83-
feed={'x': x_data,
84-
'y': y_data},
66+
feed=feeder.feed(data),
8567
fetch_list=[avg_cost] + test_accuracy.metrics)
8668

8769
test_pass_acc = test_accuracy.eval(exe)

python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
import paddle.v2.fluid as fluid
55

66

7-
def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32):
8-
data = fluid.layers.data(name="words", shape=[1], dtype="int64")
9-
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
10-
7+
def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32,
8+
hid_dim=32):
119
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim])
1210
conv_3 = fluid.nets.sequence_conv_pool(
1311
input=emb,
@@ -55,34 +53,28 @@ def main():
5553
dict_dim = len(word_dict)
5654
class_dim = 2
5755

56+
data = fluid.layers.data(
57+
name="words", shape=[1], dtype="int64", lod_level=1)
58+
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
5859
cost, accuracy, acc_out = convolution_net(
59-
input_dim=dict_dim, class_dim=class_dim)
60+
data, label, input_dim=dict_dim, class_dim=class_dim)
6061

6162
train_data = paddle.batch(
6263
paddle.reader.shuffle(
6364
paddle.dataset.imdb.train(word_dict), buf_size=1000),
6465
batch_size=BATCH_SIZE)
6566
place = fluid.CPUPlace()
6667
exe = fluid.Executor(place)
68+
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
6769

6870
exe.run(fluid.default_startup_program())
6971

7072
for pass_id in xrange(PASS_NUM):
7173
accuracy.reset(exe)
7274
for data in train_data():
73-
tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
74-
75-
label = np.array(map(lambda x: x[1], data)).astype("int64")
76-
label = label.reshape([BATCH_SIZE, 1])
77-
78-
tensor_label = fluid.LoDTensor()
79-
tensor_label.set(label, place)
80-
81-
cost_val, acc_val = exe.run(
82-
fluid.default_main_program(),
83-
feed={"words": tensor_words,
84-
"label": tensor_label},
85-
fetch_list=[cost, acc_out])
75+
cost_val, acc_val = exe.run(fluid.default_main_program(),
76+
feed=feeder.feed(data),
77+
fetch_list=[cost, acc_out])
8678
pass_acc = accuracy.eval(exe)
8779
print("cost=" + str(cost_val) + " acc=" + str(acc_val) +
8880
" pass_acc=" + str(pass_acc))

0 commit comments

Comments
 (0)