@@ -11,6 +11,7 @@ function(op_library TARGET)
11
11
set (cc_srcs )
12
12
set (cu_srcs )
13
13
set (cu_cc_srcs )
14
+ set (cudnn_cu_cc_srcs )
14
15
set (op_common_deps operator op_registry math_function )
15
16
set (options "" )
16
17
set (oneValueArgs "" )
@@ -34,6 +35,8 @@ function(op_library TARGET)
34
35
foreach (src ${op_library_SRCS} )
35
36
if (${src} MATCHES ".*\\ .cu$" )
36
37
list (APPEND cu_srcs ${src} )
38
+ elseif (${src} MATCHES ".*_cudnn_op.cu.cc$" )
39
+ list (APPEND cudnn_cu_cc_srcs ${src} )
37
40
elseif (${src} MATCHES ".*\\ .cu.cc$" )
38
41
list (APPEND cu_cc_srcs ${src} )
39
42
elseif (${src} MATCHES ".*\\ .cc$" )
@@ -54,7 +57,7 @@ function(op_library TARGET)
54
57
set (DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE )
55
58
endif ()
56
59
if (WITH_GPU )
57
- nv_library (${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
60
+ nv_library (${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${ cu_srcs} DEPS ${op_library_DEPS}
58
61
${op_common_deps} )
59
62
else ()
60
63
cc_library (${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
@@ -98,6 +101,12 @@ function(op_library TARGET)
98
101
set (pybind_flag 1 )
99
102
endif ()
100
103
104
+ # pybind USE_OP_DEVICE_KERNEL for CUDNN
105
+ list (LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len )
106
+ if (${cudnn_cu_cc_srcs_len} GREATER 0 )
107
+ file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET} , CUDNN);\n " )
108
+ endif ()
109
+
101
110
# pybind USE_OP
102
111
if (${pybind_flag} EQUAL 0 )
103
112
file (APPEND ${pybind_file} "USE_OP(${TARGET} );\n " )
@@ -159,21 +168,16 @@ op_library(create_reader_op DEPS reader)
159
168
160
169
# Regist multiple Kernel to pybind
161
170
if (WITH_GPU )
162
-
163
- op_library (conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
164
- vol2col depthwise_conv )
165
-
166
- op_library (edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function )
167
- op_library (pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling )
168
- op_library (conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
169
- conv_transpose_cudnn_op.cu.cc DEPS vol2col )
170
- file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n " )
171
- file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n " )
172
- file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n " )
171
+ op_library (conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
172
+ vol2col depthwise_conv )
173
+ op_library (edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function )
174
+ op_library (pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling )
175
+ op_library (conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
176
+ conv_transpose_cudnn_op.cu.cc DEPS vol2col )
173
177
else ()
174
- op_library (conv_op SRCS conv_op.cc DEPS vol2col )
175
- op_library (pool_op SRCS pool_op.cc DEPS pooling )
176
- op_library (conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col )
178
+ op_library (conv_op SRCS conv_op.cc DEPS vol2col )
179
+ op_library (pool_op SRCS pool_op.cc DEPS pooling )
180
+ op_library (conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col )
177
181
endif ()
178
182
179
183
cc_library (batch_size_like SRCS batch_size_like.cc DEPS op_registry )
0 commit comments