Skip to content

Commit 16fc5e3

Browse files
committed
refine cmake for cudnn
1 parent 95ea54f commit 16fc5e3

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function(op_library TARGET)
1111
set(cc_srcs)
1212
set(cu_srcs)
1313
set(cu_cc_srcs)
14+
set(cudnn_cu_cc_srcs)
1415
set(op_common_deps operator op_registry math_function)
1516
set(options "")
1617
set(oneValueArgs "")
@@ -34,6 +35,8 @@ function(op_library TARGET)
3435
foreach(src ${op_library_SRCS})
3536
if (${src} MATCHES ".*\\.cu$")
3637
list(APPEND cu_srcs ${src})
38+
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
39+
list(APPEND cudnn_cu_cc_srcs ${src})
3740
elseif(${src} MATCHES ".*\\.cu.cc$")
3841
list(APPEND cu_cc_srcs ${src})
3942
elseif(${src} MATCHES ".*\\.cc$")
@@ -54,7 +57,7 @@ function(op_library TARGET)
5457
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
5558
endif()
5659
if (WITH_GPU)
57-
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
60+
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
5861
${op_common_deps})
5962
else()
6063
cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
@@ -98,6 +101,12 @@ function(op_library TARGET)
98101
set(pybind_flag 1)
99102
endif()
100103

104+
# pybind USE_OP_DEVICE_KERNEL for CUDNN
105+
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
106+
if (${cudnn_cu_cc_srcs_len} GREATER 0)
107+
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
108+
endif()
109+
101110
# pybind USE_OP
102111
if (${pybind_flag} EQUAL 0)
103112
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
@@ -159,21 +168,16 @@ op_library(create_reader_op DEPS reader)
159168

160169
# Regist multiple Kernel to pybind
161170
if (WITH_GPU)
162-
163-
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
164-
vol2col depthwise_conv)
165-
166-
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
167-
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
168-
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
169-
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
170-
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n")
171-
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n")
172-
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n")
171+
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
172+
vol2col depthwise_conv)
173+
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
174+
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
175+
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
176+
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
173177
else()
174-
op_library(conv_op SRCS conv_op.cc DEPS vol2col)
175-
op_library(pool_op SRCS pool_op.cc DEPS pooling)
176-
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
178+
op_library(conv_op SRCS conv_op.cc DEPS vol2col)
179+
op_library(pool_op SRCS pool_op.cc DEPS pooling)
180+
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
177181
endif()
178182

179183
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)

0 commit comments

Comments
 (0)