File tree Expand file tree Collapse file tree 3 files changed +37
-0
lines changed Expand file tree Collapse file tree 3 files changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,12 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
71
71
const framework::ExecutionContext& ctx) const {
72
72
bool use_cudnn = ctx.Attr <bool >(" use_cudnn" );
73
73
use_cudnn &= platform::is_gpu_place (ctx.GetPlace ());
74
+ #ifdef PADDLE_WITH_CUDA
75
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
76
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
77
+ use_cudnn &= dev_ctx.cudnn_handle () != nullptr ;
78
+ }
79
+ #endif
74
80
framework::LibraryType library_;
75
81
if (use_cudnn) {
76
82
library_ = framework::LibraryType::kCUDNN ;
@@ -285,6 +291,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
285
291
const framework::ExecutionContext& ctx) const {
286
292
bool use_cudnn = ctx.Attr <bool >(" use_cudnn" );
287
293
use_cudnn &= platform::is_gpu_place (ctx.GetPlace ());
294
+ #ifdef PADDLE_WITH_CUDA
295
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
296
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
297
+ use_cudnn &= dev_ctx.cudnn_handle () != nullptr ;
298
+ }
299
+ #endif
300
+
288
301
framework::LibraryType library_;
289
302
if (use_cudnn) {
290
303
library_ = framework::LibraryType::kCUDNN ;
Original file line number Diff line number Diff line change @@ -62,6 +62,12 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
62
62
const framework::ExecutionContext& ctx) const {
63
63
bool use_cudnn = ctx.Attr <bool >(" use_cudnn" );
64
64
use_cudnn &= platform::is_gpu_place (ctx.GetPlace ());
65
+ #ifdef PADDLE_WITH_CUDA
66
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
67
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
68
+ use_cudnn &= dev_ctx.cudnn_handle () != nullptr ;
69
+ }
70
+ #endif
65
71
framework::LibraryType library_;
66
72
if (use_cudnn) {
67
73
library_ = framework::LibraryType::kCUDNN ;
@@ -265,6 +271,12 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
265
271
const framework::ExecutionContext& ctx) const {
266
272
bool use_cudnn = ctx.Attr <bool >(" use_cudnn" );
267
273
use_cudnn &= platform::is_gpu_place (ctx.GetPlace ());
274
+ #ifdef PADDLE_WITH_CUDA
275
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
276
+ auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
277
+ use_cudnn &= dev_ctx.cudnn_handle () != nullptr ;
278
+ }
279
+ #endif
268
280
framework::LibraryType library_;
269
281
if (use_cudnn) {
270
282
library_ = framework::LibraryType::kCUDNN ;
Original file line number Diff line number Diff line change @@ -65,6 +65,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
65
65
const framework::ExecutionContext &ctx) const {
66
66
bool use_cudnn = ctx.Attr <bool >(" use_cudnn" );
67
67
use_cudnn &= platform::is_gpu_place (ctx.GetPlace ());
68
+ #ifdef PADDLE_WITH_CUDA
69
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
70
+ auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
71
+ use_cudnn &= dev_ctx.cudnn_handle () != nullptr ;
72
+ }
73
+ #endif
68
74
framework::LibraryType library_;
69
75
if (use_cudnn) {
70
76
library_ = framework::LibraryType::kCUDNN ;
@@ -90,6 +96,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
90
96
const framework::ExecutionContext &ctx) const {
91
97
bool use_cudnn = ctx.Attr <bool >(" use_cudnn" );
92
98
use_cudnn &= platform::is_gpu_place (ctx.GetPlace ());
99
+ #ifdef PADDLE_WITH_CUDA
100
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
101
+ auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
102
+ use_cudnn &= dev_ctx.cudnn_handle () != nullptr ;
103
+ }
104
+ #endif
93
105
framework::LibraryType library_;
94
106
if (use_cudnn) {
95
107
library_ = framework::LibraryType::kCUDNN ;
You can’t perform that action at this time.
0 commit comments