Skip to content

Commit 4b95095

Browse files
committed
Add unittests and fix a few bugs
1 parent ba53801 commit 4b95095

File tree

6 files changed

+206
-8
lines changed

6 files changed

+206
-8
lines changed

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ void DataBalanceOpHandle::RunImpl() {
107107
auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name_);
108108
PADDLE_ENFORCE(tensor_var->IsType<LoDTensor>());
109109
auto *tensor = tensor_var->GetMutable<LoDTensor>();
110-
PADDLE_ENFORCE(places_[place_idx] == tensor->place());
111110
lod_tensors[data_idx].push_back(tensor);
112111
int ins_size =
113112
tensor->lod().empty() ? tensor->dims()[0] : tensor->NumElements();

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ void FetchOpHandle::RunImpl() {
6767
#endif
6868
} else {
6969
tensors_[i].ShareDataWith(t);
70-
tensors_[i].set_lod(t.lod());
7170
}
71+
tensors_[i].set_lod(t.lod());
7272
}
7373

7474
this->WaitAndMergeCPUTensors();

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
216216
} else {
217217
// This op runs on all devices, and its output may have parameter's
218218
// gradients.
219-
CreateComputationalOps(&result, *op, places_.size());
220-
221219
if (op->Type() == "read") {
220+
op->SetAttr("throw_eof_exp", false);
221+
CreateComputationalOps(&result, *op, places_.size());
222222
const auto &data_var_names = op->Output("Out");
223223
InsertDataBalanceOp(&result, data_var_names);
224+
} else {
225+
CreateComputationalOps(&result, *op, places_.size());
224226
}
225227

226228
if (!is_forwarding && places_.size() > 1) {

paddle/fluid/framework/lod_tensor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ void LoDTensor::MergeLoDTensor(
393393
new_dim[0] += t->dims()[0];
394394

395395
auto &lod = t->lod();
396+
PADDLE_ENFORCE_EQ(new_lod.size(), lod.size());
396397
for (size_t j = 0; j < lod.size(); ++j) {
397398
auto &sub_lod = new_lod[j];
398399
auto &offset = sub_lod.back();

paddle/fluid/operators/read_op.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,14 @@ class ReadOp : public framework::OperatorBase {
6767
std::vector<framework::LoDTensor> ins;
6868
reader->ReadNext(&ins);
6969
if (ins.empty()) {
70-
ins.resize(out_arg_names.size());
71-
for (auto& tensor : ins) {
72-
// data type is not important for subsequent DataBalanceOpHandle
73-
tensor.mutable_data<float>(framework::make_ddim({0}), dev_place);
70+
if (Attr<bool>("throw_eof_exp")) {
71+
PADDLE_THROW("There is no next data.");
72+
} else {
73+
ins.resize(out_arg_names.size());
74+
for (auto& tensor : ins) {
75+
// data type is not important for subsequent DataBalanceOpHandle
76+
tensor.mutable_data<float>(framework::make_ddim({0}), dev_place);
77+
}
7478
}
7579
}
7680
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
@@ -88,6 +92,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
8892
void Make() override {
8993
AddInput("Reader", "(ReaderHolder) The executed reader.");
9094
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
95+
AddAttr<bool>("throw_eof_exp",
96+
"If set true, an exception will be thrown when the Reader "
97+
"yields empty (which means there is no next data).")
98+
.SetDefault(true);
9199
AddComment(R"DOC(
92100
Read Operator
93101
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle.fluid as fluid
17+
import paddle.v2 as paddle
18+
import paddle.v2.dataset.mnist as mnist
19+
import numpy as np
20+
21+
22+
class TestDataBalance(unittest.TestCase):
    """Exercises reading recordio files through ParallelExecutor.

    Two datasets are written in setUp (a dense one and one carrying a level-1
    LoD). Each test body then reads its dataset back through a
    ParallelExecutor until the run raises EnforceNotMet, and checks that every
    generated instance was fetched exactly once with consistent contents.
    """

    # Text the tests expect inside the EnforceNotMet raised once the reader
    # has no more data.
    EOF_MESSAGE = "There is no next data."

    def prepare_data(self):
        """Write the dense dataset to ``self.data_file_name``.

        Instance ``n`` (0 <= n < total_ins_num) is a 3x4 array filled with
        ``n`` paired with the label ``n``, so a fetched image can be checked
        against its own label. The value returned by
        ``convert_reader_to_recordio_file`` is kept in ``self.num_batches``.
        """

        def fake_data_generator():
            for n in range(self.total_ins_num):
                yield np.ones((3, 4)) * n, n

        # Build the feed variables in throwaway programs so the default
        # programs are left untouched.
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            reader = paddle.batch(
                fake_data_generator, batch_size=self.batch_size)
            feeder = fluid.DataFeeder(
                feed_list=[
                    fluid.layers.data(
                        name='image', shape=[3, 4], dtype='float32'),
                    fluid.layers.data(
                        name='label', shape=[1], dtype='int64'),
                ],
                place=fluid.CPUPlace())
            self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
                self.data_file_name, reader, feeder)

    def prepare_lod_data(self):
        """Write the LoD dataset to ``self.lod_data_file_name``.

        Instance ``n`` (1 <= n <= total_ins_num) is an (n, 3) array filled
        with ``n`` paired with the (1, 1) label ``n``. Up to ``batch_size``
        instances are concatenated along axis 0 into one record; the first
        tensor of each record carries a single-level LoD holding the row
        offsets of its instances.
        """

        def fake_data_generator():
            for n in range(1, self.total_ins_num + 1):
                d1 = (np.ones((n, 3)) * n).astype('float32')
                d2 = (np.array(n).reshape((1, 1))).astype('int32')
                yield d1, d2

        with fluid.program_guard(fluid.Program(), fluid.Program()):
            with fluid.recordio_writer.create_recordio_writer(
                    filename=self.lod_data_file_name) as writer:
                eof = False
                generator = fake_data_generator()
                while not eof:
                    # Start every record from empty arrays and grow them one
                    # instance at a time.
                    data_batch = [
                        np.array([]).reshape((0, 3)),
                        np.array([]).reshape((0, 1)),
                    ]
                    lod = [0]
                    for _ in range(self.batch_size):
                        try:
                            ins = next(generator)
                        except StopIteration:
                            eof = True
                            break
                        for i, d in enumerate(ins):
                            data_batch[i] = np.concatenate(
                                (data_batch[i], d), axis=0)
                        lod.append(lod[-1] + ins[0].shape[0])
                    # The generator may run dry exactly on a batch boundary;
                    # do not write an empty trailing record in that case.
                    if data_batch[0].shape[0] > 0:
                        for i, d in enumerate(data_batch):
                            t = fluid.LoDTensor()
                            t.set(d, fluid.CPUPlace())
                            if i == 0:
                                t.set_lod([lod])
                            writer.append_tensor(t)
                        writer.complete_append_tensor()

    def setUp(self):
        """Pick the device type, set the dataset sizes and write both files."""
        self.use_cuda = fluid.core.is_compiled_with_cuda()
        self.data_file_name = './data_balance_test.recordio'
        self.lod_data_file_name = './data_balance_with_lod_test.recordio'
        self.total_ins_num = 50
        self.batch_size = 10
        self.prepare_data()
        self.prepare_lod_data()

    def _create_parallel_executor(self, main_prog, startup_prog):
        """Run ``startup_prog`` and return a ParallelExecutor for ``main_prog``.

        Skips the running test when there are more devices than instances in
        a batch, a configuration for which the checks below do not hold.
        """
        place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)

        parallel_exe = fluid.ParallelExecutor(
            use_cuda=self.use_cuda, main_program=main_prog)

        if parallel_exe.device_count > self.batch_size:
            # Report a skip through unittest instead of calling exit(0),
            # which would end the whole test process and hide the skip.
            self.skipTest(
                "TestDataBalance skipped: the result is not correct when "
                "device count is larger than batch size.")
        return parallel_exe

    def main(self):
        """Read the dense dataset and verify its contents and coverage."""
        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        with fluid.program_guard(main_prog, startup_prog):
            data_reader = fluid.layers.io.open_files(
                filenames=[self.data_file_name],
                shapes=[[-1, 3, 4], [-1, 1]],
                lod_levels=[0, 0],
                dtypes=['float32', 'int64'])
            if self.use_cuda:
                data_reader = fluid.layers.double_buffer(data_reader)
            image, label = fluid.layers.read_file(data_reader)

            parallel_exe = self._create_parallel_executor(main_prog,
                                                          startup_prog)
            fetch_list = [image.name, label.name]

            data_appeared = [False] * self.total_ins_num
            while True:
                try:
                    image_val, label_val = parallel_exe.run(fetch_list,
                                                            return_numpy=True)
                except fluid.core.EnforceNotMet as ex:
                    # Running out of data is the expected way to leave the
                    # loop; any other enforce failure fails the assertion.
                    self.assertIn(self.EOF_MESSAGE, str(ex))
                    break
                ins_num = image_val.shape[0]
                broadcasted_label = np.ones(
                    (ins_num, 3, 4)) * label_val.reshape((ins_num, 1, 1))
                # Every image must be filled with its own label value.
                # Comparing `.all()` of both sides would only compare two
                # booleans and pass for unrelated data.
                self.assertTrue(np.array_equal(image_val, broadcasted_label))
                for l in label_val:
                    self.assertFalse(data_appeared[l[0]])
                    data_appeared[l[0]] = True
            for appeared in data_appeared:
                self.assertTrue(appeared)

    def main_lod(self):
        """Read the LoD dataset and verify its contents and coverage."""
        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        with fluid.program_guard(main_prog, startup_prog):
            data_reader = fluid.layers.io.open_files(
                filenames=[self.lod_data_file_name],
                shapes=[[-1, 3], [-1, 1]],
                lod_levels=[1, 0],
                dtypes=['float32', 'int32'],
                thread_num=1)
            ins, label = fluid.layers.read_file(data_reader)

            parallel_exe = self._create_parallel_executor(main_prog,
                                                          startup_prog)
            fetch_list = [ins.name, label.name]

            data_appeared = [False] * self.total_ins_num
            while True:
                try:
                    ins_tensor, label_tensor = parallel_exe.run(
                        fetch_list, return_numpy=False)
                except fluid.core.EnforceNotMet as ex:
                    self.assertIn(self.EOF_MESSAGE, str(ex))
                    break

                ins_val = np.array(ins_tensor)
                label_val = np.array(label_tensor)
                ins_lod = ins_tensor.lod()[0]
                self.assertEqual(ins_val.shape[1], 3)
                self.assertEqual(label_val.shape[1], 1)
                # One LoD segment per fetched label.
                self.assertEqual(len(ins_lod) - 1, label_val.shape[0])
                for i in range(0, len(ins_lod) - 1):
                    ins_elem = ins_val[ins_lod[i]:ins_lod[i + 1]][:]
                    label_elem = label_val[i][0]
                    # Each sequence was generated filled with its label, so
                    # compare element-wise rather than comparing `.all()`.
                    self.assertTrue(np.all(ins_elem == label_elem))
                    # Labels start at 1, hence the shift to a 0-based slot.
                    self.assertFalse(data_appeared[int(label_elem - 1)])
                    data_appeared[int(label_elem - 1)] = True

            for appeared in data_appeared:
                self.assertTrue(appeared)

    def test_all(self):
        """Run the dense check followed by the LoD check."""
        self.main()
        self.main_lod()

0 commit comments

Comments
 (0)