Skip to content

Commit ce16b40

Browse files
authored
Merge pull request #11891 from JiayiFeng/dev_eof_exp
Add EOFException to represent EOF in C++ reader
2 parents 037ce12 + ed4b247 commit ce16b40

File tree

11 files changed

+54
-16
lines changed

11 files changed

+54
-16
lines changed

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
6262
}
6363
if (total_size < device_num) {
6464
// Not enough data.
65-
PADDLE_THROW("There is no next data.");
65+
PADDLE_THROW_EOF();
6666
}
6767
std::sort(size_device_vec.begin(), size_device_vec.end(),
6868
[](const std::array<int, 2> &a, const std::array<int, 2> &b) {

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
9898
if (timeout) {
9999
std::lock_guard<std::mutex> l(exception_mu_);
100100
if (exception_) {
101-
auto exp = *exception_;
102-
exception_.reset();
103-
throw exp;
101+
std::exception *exp = exception_.get();
102+
if (dynamic_cast<platform::EOFException *>(exp)) {
103+
auto e = *static_cast<platform::EOFException *>(exp);
104+
exception_.reset();
105+
throw e;
106+
} else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
107+
auto e = *static_cast<platform::EnforceNotMet *>(exp);
108+
exception_.reset();
109+
throw e;
110+
} else {
111+
LOG(FATAL) << "Unknown exception.";
112+
}
104113
} else {
105114
continue;
106115
}
@@ -199,6 +208,12 @@ void ThreadedSSAGraphExecutor::RunOp(
199208
running_ops_--;
200209
ready_var_q->Extend(op->Outputs());
201210
VLOG(10) << op << " " << op->Name() << "Signal posted";
211+
} catch (platform::EOFException ex) {
212+
std::lock_guard<std::mutex> l(exception_mu_);
213+
// EOFException will not cover up existing EnforceNotMet.
214+
if (exception_.get() == nullptr) {
215+
exception_.reset(new platform::EOFException(ex));
216+
}
202217
} catch (platform::EnforceNotMet ex) {
203218
std::lock_guard<std::mutex> l(exception_mu_);
204219
exception_.reset(new platform::EnforceNotMet(ex));

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5757
std::vector<platform::Place> places_;
5858
platform::DeviceContextPool fetch_ctxs_;
5959
std::mutex exception_mu_;
60-
std::unique_ptr<platform::EnforceNotMet> exception_;
60+
std::unique_ptr<std::exception> exception_;
6161
std::atomic<int> running_ops_;
6262

6363
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,

paddle/fluid/operators/read_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class ReadOp : public framework::OperatorBase {
6868
reader->ReadNext(&ins);
6969
if (ins.empty()) {
7070
if (Attr<bool>("throw_eof_exp")) {
71-
PADDLE_THROW("There is no next data.");
71+
PADDLE_THROW_EOF();
7272
} else {
7373
ins.resize(out_arg_names.size());
7474
for (auto& tensor : ins) {

paddle/fluid/platform/enforce.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ struct EnforceNotMet : public std::exception {
102102
const char* what() const noexcept { return err_str_.c_str(); }
103103
};
104104

105+
// Exception type used to signal end-of-data (EOF), e.g. via PADDLE_THROW_EOF().
// The message returned by what() has the form "<err_msg> at [<file>:<line>]".
struct EOFException : public std::exception {
  // Fully formatted message; owns the storage returned by what().
  std::string err_str_;

  // err_msg: human-readable description (must be a non-null C string).
  // f:       source file where the exception is raised (non-null C string).
  // l:       source line where the exception is raised.
  EOFException(const char* err_msg, const char* f, int l)
      : err_str_(std::string(err_msg) + " at [" + f + ":" +
                 std::to_string(l) + "]") {}

  // Returns the formatted message; the pointer stays valid for as long as
  // this exception object is alive and unmodified.
  const char* what() const noexcept override { return err_str_.c_str(); }
};
113+
105114
// Because most enforce conditions would evaluate to true, we can use
106115
// __builtin_expect to instruct the C++ compiler to generate code that
107116
// always forces branch prediction of true.
@@ -242,6 +251,11 @@ inline void throw_on_error(T e) {
242251
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
243252
#endif
244253

254+
// Throws ::paddle::platform::EOFException with the fixed message
// "There is no next data." plus the file and line of the expansion site.
// The do { ... } while (false) wrapper makes the expansion a single statement,
// so `PADDLE_THROW_EOF();` is safe inside unbraced if/else bodies.
#define PADDLE_THROW_EOF()                               \
  do {                                                   \
    throw ::paddle::platform::EOFException(              \
        "There is no next data.", __FILE__, __LINE__);   \
  } while (false)
245259
/*
246260
* Some enforce helpers here, usage:
247261
* int a = 1;

paddle/fluid/platform/enforce_test.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,14 @@ TEST(ENFORCE_USER_DEFINED_CLASS, NE) {
210210
Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}};
211211
ASSERT_THROW(PADDLE_ENFORCE_EQ(a, b), paddle::platform::EnforceNotMet);
212212
}
213+
214+
// Checks that PADDLE_THROW_EOF() raises paddle::platform::EOFException and
// that its message begins with "There is no next data.".
TEST(EOF_EXCEPTION, THROW_EOF) {
  bool caught_eof = false;
  try {
    PADDLE_THROW_EOF();
  } catch (const paddle::platform::EOFException& error) {
    // Catch by const reference: avoids copying the exception object and
    // would not slice a more-derived exception type.
    caught_eof = true;
    EXPECT_TRUE(HasPrefix(StringPiece(error.what()), "There is no next data."));
  }
  EXPECT_TRUE(caught_eof);
}

paddle/fluid/pybind/exception.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ namespace paddle {
1818
namespace pybind {
1919

2020
void BindException(pybind11::module* m) {
21+
static pybind11::exception<platform::EOFException> eof(*m, "EOFException");
2122
static pybind11::exception<platform::EnforceNotMet> exc(*m, "EnforceNotMet");
2223
pybind11::register_exception_translator([](std::exception_ptr p) {
2324
try {
2425
if (p) std::rethrow_exception(p);
26+
} catch (const platform::EOFException& e) {
27+
eof(e.what());
2528
} catch (const platform::EnforceNotMet& e) {
2629
exc(e.what());
2730
}

python/paddle/fluid/tests/unittests/test_data_balance.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ def main(self):
118118
try:
119119
image_val, label_val = parallel_exe.run(fetch_list,
120120
return_numpy=True)
121-
except fluid.core.EnforceNotMet as ex:
122-
self.assertIn("There is no next data.", ex.message)
121+
except fluid.core.EOFException:
123122
break
124123
ins_num = image_val.shape[0]
125124
broadcasted_label = np.ones(
@@ -162,8 +161,7 @@ def main_lod(self):
162161
try:
163162
ins_tensor, label_tensor = parallel_exe.run(
164163
fetch_list, return_numpy=False)
165-
except fluid.core.EnforceNotMet as ex:
166-
self.assertIn("There is no next data.", ex.message)
164+
except fluid.core.EOFException:
167165
break
168166

169167
ins_val = np.array(ins_tensor)

python/paddle/fluid/tests/unittests/test_multi_file_reader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ def main(self, thread_num):
6464
while True:
6565
try:
6666
img_val, = exe.run(fetch_list=[img])
67-
except fluid.core.EnforceNotMet as ex:
68-
self.assertIn("There is no next data.", ex.message)
67+
except fluid.core.EOFException:
6968
break
7069
batch_count += 1
7170
self.assertLessEqual(img_val.shape[0], self.batch_size)

python/paddle/fluid/tests/unittests/test_multi_pass_reader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def test_main(self):
5959
while True:
6060
try:
6161
img_val, = exe.run(fetch_list=[img])
62-
except fluid.core.EnforceNotMet as ex:
63-
self.assertIn("There is no next data.", ex.message)
62+
except fluid.core.EOFException:
6463
break
6564
batch_count += 1
6665
self.assertLessEqual(img_val.shape[0], self.batch_size)

0 commit comments

Comments
 (0)