Skip to content

Commit f403f69

Browse files
committed
Accelerate PADDLE_ENFORCE
test=release/1.2
1 parent 847cbdc commit f403f69

File tree

10 files changed

+109
-55
lines changed

10 files changed

+109
-55
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
120120
ClearFetchOp(graph_.get(), &fetch_ops);
121121
return fetches;
122122
}
123+
123124
void FastThreadedSSAGraphExecutor::RunOpAsync(
124125
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
125126
OpHandleBase *op,

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
163163
}
164164

165165
bool OperatorBase::HasInputs(const std::string& name) const {
166-
if (inputs_.find(name) != inputs_.end()) {
167-
return true;
168-
} else {
169-
return false;
170-
}
166+
return inputs_.find(name) != inputs_.end();
171167
}
172168

173169
std::string OperatorBase::Input(const std::string& name) const {

paddle/fluid/framework/operator.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ constexpr char kTempVarName[] = "@TEMP@";
4949
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
5050
constexpr char kGradVarSuffix[] = "@GRAD";
5151

52+
constexpr size_t kGradVarSuffixSize = 5U;
53+
5254
/// Variables with this suffix are supposed to be filled up with zeros.
5355
constexpr char kZeroVarSuffix[] = "@ZERO";
5456

@@ -60,7 +62,11 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@";
6062
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
6163

6264
inline std::string GradVarName(const std::string& var_name) {
63-
return var_name + kGradVarSuffix;
65+
std::string result;
66+
result.reserve(var_name.size() + kGradVarSuffixSize);
67+
result += var_name;
68+
result += kGradVarSuffix;
69+
return result;
6470
}
6571

6672
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
@@ -101,8 +107,8 @@ class OperatorBase {
101107
bool HasAttr(const std::string& name) const { return attrs_.count(name); }
102108
template <typename T>
103109
inline const T& Attr(const std::string& name) const {
104-
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
105-
name);
110+
PADDLE_ENFORCE(attrs_.find(name) != attrs_.end(),
111+
"%s should be in AttributeMap", name);
106112
return boost::get<T>(attrs_.at(name));
107113
}
108114
const AttributeMap& Attrs() const { return attrs_; }

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,17 @@ void TestWord2vecPrediction(const std::string& model_path) {
6969
std::vector<PaddleTensor> outputs;
7070
CHECK(predictor->Run(slots, &outputs));
7171

72-
PADDLE_ENFORCE(outputs.size(), 1UL);
72+
PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
7373
// Check the output buffer size and result of each tid.
74-
PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
74+
PADDLE_ENFORCE_EQ(outputs.front().data.length(), 33168UL);
7575
float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815,
7676
0.000932706};
7777
const size_t num_elements = outputs.front().data.length() / sizeof(float);
7878
// The outputs' buffers are in CPU memory.
7979
for (size_t i = 0; i < std::min(static_cast<size_t>(5UL), num_elements);
8080
i++) {
81-
LOG(INFO) << "data: "
82-
<< static_cast<float*>(outputs.front().data.data())[i];
81+
LOG(INFO) << "data: " << static_cast<float*>(outputs.front().data.data())[i]
82+
<< " result: " << result[i];
8383
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
8484
result[i]);
8585
}

paddle/fluid/operators/detail/safe_ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace detail {
2525
*/
2626
template <typename T, typename... ARGS>
2727
inline T& Ref(T* ptr, ARGS&&... args) {
28-
PADDLE_ENFORCE(ptr != nullptr, args...);
28+
PADDLE_ENFORCE(ptr != nullptr, ::paddle::string::Sprintf(args...));
2929
return *ptr;
3030
}
3131

paddle/fluid/operators/distributed/proto_encoder_helper.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ class ProtoEncodeHelper {
8484
~ProtoEncodeHelper() {
8585
#define REPLACE_ENFORCE_GLOG 1
8686
// Make sure callers didn't do operations that went over max_size promised
87-
paddle::platform::throw_on_error(p_ <= limit_);
87+
if (paddle::platform::is_error(p_ <= limit_)) {
88+
paddle::platform::throw_on_error(p_ <= limit_);
89+
}
8890
#undef REPLACE_ENFORCE_GLOG
8991
}
9092

paddle/fluid/operators/lrn_mkldnn_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ template <typename T>
5050
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
5151
public:
5252
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
53-
PADDLE_ENFORCE(std::is_same<T, float>::value,
54-
"MKLDNN LRN must use float data.");
53+
const bool is_float_type = std::is_same<T, float>::value;
54+
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
5555
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
5656
"MKLDNN LRN must use CPUPlace.");
5757

@@ -132,8 +132,8 @@ template <typename T>
132132
class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
133133
public:
134134
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
135-
PADDLE_ENFORCE(std::is_same<T, float>::value,
136-
"MKLDNN LRN must use float data.");
135+
const bool is_float_type = std::is_same<T, float>::value;
136+
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
137137
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
138138
"MKLDNN LRN must use CPUPlace.");
139139
PADDLE_ENFORCE(

paddle/fluid/platform/enforce.h

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -131,68 +131,72 @@ struct EOFException : public std::exception {
131131
#define LIKELY(condition) (condition)
132132
#endif
133133

134+
inline bool is_error(bool stat) { return !stat; }
135+
134136
template <typename... Args>
135137
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
136138
bool stat, const Args&... args) {
137-
if (UNLIKELY(!(stat))) {
138139
#ifndef REPLACE_ENFORCE_GLOG
139-
throw std::runtime_error(string::Sprintf(args...));
140+
throw std::runtime_error(string::Sprintf(args...));
140141
#else
141-
LOG(FATAL) << string::Sprintf(args...);
142+
LOG(FATAL) << string::Sprintf(args...);
142143
#endif
143-
}
144144
}
145145

146146
#ifdef PADDLE_WITH_CUDA
147147

148+
inline bool is_error(cudaError_t e) { return UNLIKELY(e); }
149+
148150
template <typename... Args>
149151
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
150152
cudaError_t e, const Args&... args) {
151-
if (UNLIKELY(e)) {
152153
#ifndef REPLACE_ENFORCE_GLOG
153-
throw thrust::system_error(e, thrust::cuda_category(),
154-
string::Sprintf(args...));
154+
throw thrust::system_error(e, thrust::cuda_category(),
155+
string::Sprintf(args...));
155156
#else
156-
LOG(FATAL) << string::Sprintf(args...);
157+
LOG(FATAL) << string::Sprintf(args...);
157158
#endif
158-
}
159+
}
160+
161+
inline bool is_error(curandStatus_t stat) {
162+
return stat != CURAND_STATUS_SUCCESS;
159163
}
160164

161165
template <typename... Args>
162166
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
163167
curandStatus_t stat, const Args&... args) {
164-
if (stat != CURAND_STATUS_SUCCESS) {
165168
#ifndef REPLACE_ENFORCE_GLOG
166-
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
167-
string::Sprintf(args...));
169+
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
170+
string::Sprintf(args...));
168171
#else
169-
LOG(FATAL) << string::Sprintf(args...);
172+
LOG(FATAL) << string::Sprintf(args...);
170173
#endif
171-
}
174+
}
175+
176+
inline bool is_error(cudnnStatus_t stat) {
177+
return stat != CUDNN_STATUS_SUCCESS;
172178
}
173179

174180
template <typename... Args>
175181
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
176182
cudnnStatus_t stat, const Args&... args) {
177-
if (stat == CUDNN_STATUS_SUCCESS) {
178-
return;
179-
} else {
180183
#ifndef REPLACE_ENFORCE_GLOG
181-
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
182-
string::Sprintf(args...));
184+
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
185+
string::Sprintf(args...));
183186
#else
184-
LOG(FATAL) << string::Sprintf(args...);
187+
LOG(FATAL) << string::Sprintf(args...);
185188
#endif
186-
}
189+
}
190+
191+
inline bool is_error(cublasStatus_t stat) {
192+
return stat != CUBLAS_STATUS_SUCCESS;
187193
}
188194

189195
template <typename... Args>
190196
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
191197
cublasStatus_t stat, const Args&... args) {
192198
std::string err;
193-
if (stat == CUBLAS_STATUS_SUCCESS) {
194-
return;
195-
} else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
199+
if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
196200
err = "CUBLAS: not initialized, ";
197201
} else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
198202
err = "CUBLAS: alloc failed, ";
@@ -219,20 +223,18 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
219223
}
220224

221225
#if !defined(__APPLE__) && !defined(_WIN32)
226+
inline bool is_error(ncclResult_t stat) { return stat != ncclSuccess; }
227+
222228
template <typename... Args>
223229
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
224230
ncclResult_t stat, const Args&... args) {
225-
if (stat == ncclSuccess) {
226-
return;
227-
} else {
228231
#ifndef REPLACE_ENFORCE_GLOG
229-
throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) +
230-
string::Sprintf(args...));
232+
throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) +
233+
string::Sprintf(args...));
231234
#else
232-
LOG(FATAL) << platform::dynload::ncclGetErrorString(stat)
233-
<< string::Sprintf(args...);
235+
LOG(FATAL) << platform::dynload::ncclGetErrorString(stat)
236+
<< string::Sprintf(args...);
234237
#endif
235-
}
236238
}
237239
#endif // __APPLE__ and windows
238240
#endif // PADDLE_WITH_CUDA
@@ -250,21 +252,49 @@ inline void throw_on_error(T e) {
250252
__FILE__, __LINE__); \
251253
} while (false)
252254

255+
#define __PADDLE_THROW_ERROR_I(_, _9, _8, _7, _6, _5, _4, _3, _2, X_, ...) X_;
256+
257+
#define __THROW_ON_ERROR_ONE_ARG(COND, ARG) \
258+
::paddle::platform::throw_on_error(COND, ::paddle::string::Sprintf(ARG));
259+
260+
#define __PADDLE_THROW_ON_ERROR(COND, ...) \
261+
__PADDLE_THROW_ERROR_I( \
262+
__VA_ARGS__, ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
263+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
264+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
265+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
266+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
267+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
268+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
269+
::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
270+
__THROW_ON_ERROR_ONE_ARG(COND, __VA_ARGS__))
271+
272+
#define __PADDLE_UNARY_COMPARE(COND, ...) \
273+
do { \
274+
auto __cond = COND; \
275+
if (UNLIKELY(::paddle::platform::is_error(__cond))) { \
276+
__PADDLE_THROW_ON_ERROR(__cond, __VA_ARGS__); \
277+
} \
278+
} while (0)
279+
253280
#ifndef REPLACE_ENFORCE_GLOG
254-
#define PADDLE_ENFORCE(...) \
281+
#define __PADDLE_ENFORCE_I(COND, ...) \
255282
do { \
256283
try { \
257-
::paddle::platform::throw_on_error(__VA_ARGS__); \
284+
__PADDLE_UNARY_COMPARE(COND, __VA_ARGS__); \
258285
} catch (...) { \
259286
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
260287
__FILE__, __LINE__); \
261288
} \
262-
} while (false)
289+
} while (0)
263290

264291
#else
265-
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
292+
#define __PADDLE_ENFORCE_I(COND, ...) __PADDLE_UNARY_COMPARE(COND, __VA_ARGS__);
266293
#endif // REPLACE_ENFORCE_GLOG
267294

295+
#define __PADDLE_ENFORCE(__args) __PADDLE_ENFORCE_I __args
296+
#define PADDLE_ENFORCE(...) __PADDLE_ENFORCE((__VA_ARGS__))
297+
268298
#define PADDLE_THROW_EOF() \
269299
do { \
270300
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \

paddle/fluid/platform/enforce_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ TEST(ENFORCE, FAILED) {
3737
HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
3838
}
3939
EXPECT_TRUE(caught_exception);
40+
41+
caught_exception = false;
42+
try {
43+
PADDLE_ENFORCE(false, "Enforce is not ok at all");
44+
} catch (paddle::platform::EnforceNotMet error) {
45+
caught_exception = true;
46+
EXPECT_TRUE(
47+
HasPrefix(StringPiece(error.what()), "Enforce is not ok at all"));
48+
}
49+
EXPECT_TRUE(caught_exception);
50+
51+
caught_exception = false;
52+
try {
53+
PADDLE_ENFORCE(false);
54+
} catch (paddle::platform::EnforceNotMet error) {
55+
caught_exception = true;
56+
EXPECT_NE(std::string(error.what()).find(" at "), 0);
57+
}
58+
EXPECT_TRUE(caught_exception);
4059
}
4160

4261
TEST(ENFORCE, NO_ARG_OK) {

paddle/fluid/string/printf.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
8787
template <typename... Args>
8888
std::string Sprintf(const Args&... args) {
8989
std::ostringstream oss;
90-
Fprintf(oss, "");
90+
Fprintf(oss, "%s", args...);
9191
return oss.str();
9292
}
9393

0 commit comments

Comments (0)