Skip to content

Commit c9641a0

Browse files
committed
refine code
1 parent ed7e74a commit c9641a0

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

paddle/operators/conv_op.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
7171
const framework::ExecutionContext& ctx) const {
7272
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
7373
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
74+
#ifdef PADDLE_WITH_CUDA
75+
if (platform::is_gpu_place(ctx.GetPlace())) {
76+
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
77+
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
78+
}
79+
#endif
7480
framework::LibraryType library_;
7581
if (use_cudnn) {
7682
library_ = framework::LibraryType::kCUDNN;
@@ -285,6 +291,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
285291
const framework::ExecutionContext& ctx) const {
286292
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
287293
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
294+
#ifdef PADDLE_WITH_CUDA
295+
if (platform::is_gpu_place(ctx.GetPlace())) {
296+
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
297+
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
298+
}
299+
#endif
300+
288301
framework::LibraryType library_;
289302
if (use_cudnn) {
290303
library_ = framework::LibraryType::kCUDNN;

paddle/operators/conv_transpose_op.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
6262
const framework::ExecutionContext& ctx) const {
6363
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
6464
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
65+
#ifdef PADDLE_WITH_CUDA
66+
if (platform::is_gpu_place(ctx.GetPlace())) {
67+
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
68+
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
69+
}
70+
#endif
6571
framework::LibraryType library_;
6672
if (use_cudnn) {
6773
library_ = framework::LibraryType::kCUDNN;
@@ -265,6 +271,12 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
265271
const framework::ExecutionContext& ctx) const {
266272
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
267273
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
274+
#ifdef PADDLE_WITH_CUDA
275+
if (platform::is_gpu_place(ctx.GetPlace())) {
276+
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
277+
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
278+
}
279+
#endif
268280
framework::LibraryType library_;
269281
if (use_cudnn) {
270282
library_ = framework::LibraryType::kCUDNN;

paddle/operators/pool_op.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
6565
const framework::ExecutionContext &ctx) const {
6666
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
6767
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
68+
#ifdef PADDLE_WITH_CUDA
69+
if (platform::is_gpu_place(ctx.GetPlace())) {
70+
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
71+
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
72+
}
73+
#endif
6874
framework::LibraryType library_;
6975
if (use_cudnn) {
7076
library_ = framework::LibraryType::kCUDNN;
@@ -90,6 +96,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
9096
const framework::ExecutionContext &ctx) const {
9197
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
9298
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
99+
#ifdef PADDLE_WITH_CUDA
100+
if (platform::is_gpu_place(ctx.GetPlace())) {
101+
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
102+
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
103+
}
104+
#endif
93105
framework::LibraryType library_;
94106
if (use_cudnn) {
95107
library_ = framework::LibraryType::kCUDNN;

0 commit comments

Comments
 (0)