Skip to content

Commit 38863a2

Browse files
authored
Merge pull request #12454 from JiayiFeng/dev_exception_holder
Exception Holder
2 parents 56b50ee + bc1b7b9 commit 38863a2

File tree

3 files changed

+90
-24
lines changed

3 files changed

+90
-24
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/platform/enforce.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
class ExceptionHolder {
24+
public:
25+
void Catch(const platform::EnforceNotMet& exp) {
26+
std::lock_guard<std::mutex> lock(mu_);
27+
exception_.reset(new platform::EnforceNotMet(exp));
28+
type_ = kEnforceNotMet;
29+
}
30+
31+
void Catch(const platform::EOFException& exp) {
32+
std::lock_guard<std::mutex> lock(mu_);
33+
// EOFException will not cover up existing EnforceNotMet.
34+
if (exception_.get() == nullptr) {
35+
exception_.reset(new platform::EOFException(exp));
36+
type_ = kEOF;
37+
}
38+
}
39+
40+
bool ExceptionCatched() const {
41+
std::lock_guard<std::mutex> lock(mu_);
42+
return exception_.get() != nullptr;
43+
}
44+
45+
void Throw() {
46+
std::lock_guard<std::mutex> lock(mu_);
47+
switch (type_) {
48+
case kNone:
49+
break;
50+
case kEnforceNotMet: {
51+
auto e = *static_cast<platform::EnforceNotMet*>(exception_.get());
52+
throw e;
53+
break;
54+
}
55+
case kEOF: {
56+
auto e = *static_cast<platform::EOFException*>(exception_.get());
57+
throw e;
58+
break;
59+
}
60+
default:
61+
LOG(FATAL) << "Unknown exception.";
62+
}
63+
exception_.reset();
64+
type_ = kNone;
65+
}
66+
67+
void Clear() {
68+
std::lock_guard<std::mutex> lock(mu_);
69+
exception_.reset();
70+
type_ = kNone;
71+
}
72+
73+
private:
74+
enum ExceptionType { kNone, kEnforceNotMet, kEOF };
75+
ExceptionType type_{kNone};
76+
77+
std::unique_ptr<std::exception> exception_;
78+
mutable std::mutex mu_;
79+
};
80+
81+
} // namespace details
82+
} // namespace framework
83+
} // namespace paddle

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
8383

8484
// Clean run context
8585
run_op_futures_.clear();
86-
exception_.reset();
86+
exception_holder_.Clear();
8787

8888
// Step 3. Execution
8989
while (!pending_vars.empty()) {
@@ -103,23 +103,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
103103
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
104104

105105
if (timeout) {
106-
std::unique_lock<std::mutex> l(exception_mu_);
107-
if (exception_) {
108-
l.unlock();
106+
if (exception_holder_.ExceptionCatched()) {
109107
for (auto &run_op_future : run_op_futures_) {
110108
run_op_future.wait();
111109
}
112-
l.lock();
113-
std::exception *exp = exception_.get();
114-
if (dynamic_cast<platform::EOFException *>(exp)) {
115-
auto e = *static_cast<platform::EOFException *>(exp);
116-
throw e;
117-
} else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
118-
auto e = *static_cast<platform::EnforceNotMet *>(exp);
119-
throw e;
120-
} else {
121-
LOG(FATAL) << "Unknown exception.";
122-
}
110+
exception_holder_.Throw();
123111
} else {
124112
continue;
125113
}
@@ -229,14 +217,9 @@ void ThreadedSSAGraphExecutor::RunOp(
229217
ready_var_q->Extend(op->Outputs());
230218
VLOG(10) << op << " " << op->Name() << "Signal posted";
231219
} catch (platform::EOFException ex) {
232-
std::lock_guard<std::mutex> l(exception_mu_);
233-
// EOFException will not cover up existing EnforceNotMet.
234-
if (exception_.get() == nullptr) {
235-
exception_.reset(new platform::EOFException(ex));
236-
}
220+
exception_holder_.Catch(ex);
237221
} catch (platform::EnforceNotMet ex) {
238-
std::lock_guard<std::mutex> l(exception_mu_);
239-
exception_.reset(new platform::EnforceNotMet(ex));
222+
exception_holder_.Catch(ex);
240223
} catch (...) {
241224
LOG(FATAL) << "Unknown exception catched";
242225
}

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <functional>
2525
#include "ThreadPool.h" // ThreadPool in thrird party
2626
#include "paddle/fluid/framework/blocking_queue.h"
27+
#include "paddle/fluid/framework/details/exception_holder.h"
2728
#include "paddle/fluid/framework/details/execution_strategy.h"
2829
#include "paddle/fluid/framework/details/fetch_op_handle.h"
2930
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
@@ -59,8 +60,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5960
std::vector<Scope *> local_scopes_;
6061
std::vector<platform::Place> places_;
6162
platform::DeviceContextPool fetch_ctxs_;
62-
std::mutex exception_mu_;
63-
std::unique_ptr<std::exception> exception_;
63+
ExceptionHolder exception_holder_;
6464
std::atomic<int> running_ops_;
6565

6666
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,

0 commit comments

Comments
 (0)