Skip to content

Commit 1ac31d3

Browse files
authored
Merge pull request #8591 from chengduoZH/feature/refine_cmake_for_cudnn
Refine cmake for cudnn op
2 parents 4948f7b + 62fe2f2 commit 1ac31d3

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ function(op_library TARGET)
1111
set(cc_srcs)
1212
set(cu_srcs)
1313
set(cu_cc_srcs)
14+
set(cudnn_cu_cc_srcs)
15+
set(CUDNN_FILE)
1416
set(op_common_deps operator op_registry math_function)
1517
set(options "")
1618
set(oneValueArgs "")
@@ -30,10 +32,16 @@ function(op_library TARGET)
3032
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
3133
list(APPEND cu_srcs ${TARGET}.cu)
3234
endif()
35+
string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
36+
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
37+
list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
38+
endif()
3339
else()
3440
foreach(src ${op_library_SRCS})
3541
if (${src} MATCHES ".*\\.cu$")
3642
list(APPEND cu_srcs ${src})
43+
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
44+
list(APPEND cudnn_cu_cc_srcs ${src})
3745
elseif(${src} MATCHES ".*\\.cu.cc$")
3846
list(APPEND cu_cc_srcs ${src})
3947
elseif(${src} MATCHES ".*\\.cc$")
@@ -54,7 +62,7 @@ function(op_library TARGET)
5462
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
5563
endif()
5664
if (WITH_GPU)
57-
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
65+
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
5866
${op_common_deps})
5967
else()
6068
cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
@@ -98,6 +106,12 @@ function(op_library TARGET)
98106
set(pybind_flag 1)
99107
endif()
100108

109+
# pybind USE_OP_DEVICE_KERNEL for CUDNN
110+
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
111+
if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
112+
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
113+
endif()
114+
101115
# pybind USE_OP
102116
if (${pybind_flag} EQUAL 0)
103117
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
@@ -152,43 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
152166
op_library(lstmp_op DEPS sequence2batch lstm_compute)
153167
op_library(gru_op DEPS sequence2batch gru_compute)
154168
op_library(recurrent_op DEPS executor)
155-
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
169+
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
156170
op_library(cos_sim_op DEPS cos_sim_functor)
157171
op_library(parallel_do_op DEPS executor)
158172
op_library(create_reader_op DEPS reader)
159173

160174
# Register multiple kernels to pybind
161175
if (WITH_GPU)
162-
163-
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
164-
vol2col depthwise_conv)
165-
166-
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
167-
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
168-
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
169-
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
170-
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n")
171-
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n")
172-
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n")
176+
op_library(conv_op DEPS vol2col depthwise_conv)
173177
else()
174-
op_library(conv_op SRCS conv_op.cc DEPS vol2col)
175-
op_library(pool_op SRCS pool_op.cc DEPS pooling)
176-
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
178+
op_library(conv_op DEPS vol2col)
177179
endif()
180+
op_library(pool_op DEPS pooling)
181+
op_library(conv_transpose_op DEPS vol2col)
178182

179183
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
180-
181-
op_library(fill_constant_batch_size_like_op
182-
SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
183-
DEPS batch_size_like)
184-
185-
op_library(uniform_random_batch_size_like_op
186-
SRCS uniform_random_batch_size_like_op.cc
187-
DEPS batch_size_like uniform_random_op)
188-
189-
op_library(gaussian_random_batch_size_like_op
190-
SRCS gaussian_random_batch_size_like_op.cc
191-
DEPS batch_size_like gaussian_random_op)
184+
op_library(fill_constant_batch_size_like_op DEPS batch_size_like)
185+
op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op)
186+
op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op)
192187

193188
# FIXME(typhoonzero): save/load depend on lodtensor serialization functions
194189
op_library(save_op DEPS lod_tensor)

0 commit comments

Comments
 (0)