@@ -11,6 +11,8 @@ function(op_library TARGET)
11
11
set (cc_srcs )
12
12
set (cu_srcs )
13
13
set (cu_cc_srcs )
14
+ set (cudnn_cu_cc_srcs )
15
+ set (CUDNN_FILE )
14
16
set (op_common_deps operator op_registry math_function )
15
17
set (options "" )
16
18
set (oneValueArgs "" )
@@ -30,10 +32,16 @@ function(op_library TARGET)
30
32
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /${TARGET}.cu )
31
33
list (APPEND cu_srcs ${TARGET} .cu )
32
34
endif ()
35
+ string (REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET} " )
36
+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /${CUDNN_FILE}.cu.cc )
37
+ list (APPEND cudnn_cu_cc_srcs ${CUDNN_FILE} .cu.cc )
38
+ endif ()
33
39
else ()
34
40
foreach (src ${op_library_SRCS} )
35
41
if (${src} MATCHES ".*\\ .cu$" )
36
42
list (APPEND cu_srcs ${src} )
43
+ elseif (${src} MATCHES ".*_cudnn_op.cu.cc$" )
44
+ list (APPEND cudnn_cu_cc_srcs ${src} )
37
45
elseif (${src} MATCHES ".*\\ .cu.cc$" )
38
46
list (APPEND cu_cc_srcs ${src} )
39
47
elseif (${src} MATCHES ".*\\ .cc$" )
@@ -54,7 +62,7 @@ function(op_library TARGET)
54
62
set (DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE )
55
63
endif ()
56
64
if (WITH_GPU )
57
- nv_library (${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
65
+ nv_library (${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${ cu_srcs} DEPS ${op_library_DEPS}
58
66
${op_common_deps} )
59
67
else ()
60
68
cc_library (${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
@@ -98,6 +106,12 @@ function(op_library TARGET)
98
106
set (pybind_flag 1 )
99
107
endif ()
100
108
109
+ # pybind USE_OP_DEVICE_KERNEL for CUDNN
110
+ list (LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len )
111
+ if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0 )
112
+ file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET} , CUDNN);\n " )
113
+ endif ()
114
+
101
115
# pybind USE_OP
102
116
if (${pybind_flag} EQUAL 0 )
103
117
file (APPEND ${pybind_file} "USE_OP(${TARGET} );\n " )
@@ -152,43 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
152
166
op_library (lstmp_op DEPS sequence2batch lstm_compute )
153
167
op_library (gru_op DEPS sequence2batch gru_compute )
154
168
op_library (recurrent_op DEPS executor )
155
- op_library (warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function )
169
+ op_library (warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale )
156
170
op_library (cos_sim_op DEPS cos_sim_functor )
157
171
op_library (parallel_do_op DEPS executor )
158
172
op_library (create_reader_op DEPS reader )
159
173
160
174
# Regist multiple Kernel to pybind
161
175
if (WITH_GPU )
162
-
163
- op_library (conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
164
- vol2col depthwise_conv )
165
-
166
- op_library (edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function )
167
- op_library (pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling )
168
- op_library (conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
169
- conv_transpose_cudnn_op.cu.cc DEPS vol2col )
170
- file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n " )
171
- file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n " )
172
- file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n " )
176
+ op_library (conv_op DEPS vol2col depthwise_conv )
173
177
else ()
174
- op_library (conv_op SRCS conv_op.cc DEPS vol2col )
175
- op_library (pool_op SRCS pool_op.cc DEPS pooling )
176
- op_library (conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col )
178
+ op_library (conv_op DEPS vol2col )
177
179
endif ()
180
+ op_library (pool_op DEPS pooling )
181
+ op_library (conv_transpose_op DEPS vol2col )
178
182
179
183
cc_library (batch_size_like SRCS batch_size_like.cc DEPS op_registry )
180
-
181
- op_library (fill_constant_batch_size_like_op
182
- SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
183
- DEPS batch_size_like )
184
-
185
- op_library (uniform_random_batch_size_like_op
186
- SRCS uniform_random_batch_size_like_op.cc
187
- DEPS batch_size_like uniform_random_op )
188
-
189
- op_library (gaussian_random_batch_size_like_op
190
- SRCS gaussian_random_batch_size_like_op.cc
191
- DEPS batch_size_like gaussian_random_op )
184
+ op_library (fill_constant_batch_size_like_op DEPS batch_size_like )
185
+ op_library (uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op )
186
+ op_library (gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op )
192
187
193
188
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
194
189
op_library (save_op DEPS lod_tensor )
0 commit comments