@@ -12,6 +12,7 @@ function(op_library TARGET)
12
12
set (cu_srcs )
13
13
set (cu_cc_srcs )
14
14
set (cudnn_cu_cc_srcs )
15
+ set (CUDNN_FILE )
15
16
set (op_common_deps operator op_registry math_function )
16
17
set (options "" )
17
18
set (oneValueArgs "" )
@@ -31,6 +32,10 @@ function(op_library TARGET)
31
32
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /${TARGET}.cu )
32
33
list (APPEND cu_srcs ${TARGET} .cu )
33
34
endif ()
35
+ string (REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET} " )
36
+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /${CUDNN_FILE}.cu.cc )
37
+ list (APPEND cudnn_cu_cc_srcs ${CUDNN_FILE} .cu.cc )
38
+ endif ()
34
39
else ()
35
40
foreach (src ${op_library_SRCS} )
36
41
if (${src} MATCHES ".*\\ .cu$" )
@@ -103,7 +108,7 @@ function(op_library TARGET)
103
108
104
109
# pybind USE_OP_DEVICE_KERNEL for CUDNN
105
110
list (LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len )
106
- if (${cudnn_cu_cc_srcs_len} GREATER 0 )
111
+ if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0 )
107
112
file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET} , CUDNN);\n " )
108
113
endif ()
109
114
@@ -161,38 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
161
166
op_library (lstmp_op DEPS sequence2batch lstm_compute )
162
167
op_library (gru_op DEPS sequence2batch gru_compute )
163
168
op_library (recurrent_op DEPS executor )
164
- op_library (warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function )
169
+ op_library (warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale )
165
170
op_library (cos_sim_op DEPS cos_sim_functor )
166
171
op_library (parallel_do_op DEPS executor )
167
172
op_library (create_reader_op DEPS reader )
168
173
169
174
# Regist multiple Kernel to pybind
170
175
if (WITH_GPU )
171
- op_library (conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
172
- vol2col depthwise_conv )
173
- op_library (edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function )
174
- op_library (pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling )
175
- op_library (conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
176
- conv_transpose_cudnn_op.cu.cc DEPS vol2col )
176
+ op_library (conv_op DEPS vol2col depthwise_conv )
177
177
else ()
178
- op_library (conv_op SRCS conv_op.cc DEPS vol2col )
179
- op_library (pool_op SRCS pool_op.cc DEPS pooling )
180
- op_library (conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col )
178
+ op_library (conv_op DEPS vol2col )
181
179
endif ()
180
+ op_library (pool_op DEPS pooling )
181
+ op_library (conv_transpose_op DEPS vol2col )
182
182
183
183
cc_library (batch_size_like SRCS batch_size_like.cc DEPS op_registry )
184
-
185
- op_library (fill_constant_batch_size_like_op
186
- SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
187
- DEPS batch_size_like )
188
-
189
- op_library (uniform_random_batch_size_like_op
190
- SRCS uniform_random_batch_size_like_op.cc
191
- DEPS batch_size_like uniform_random_op )
192
-
193
- op_library (gaussian_random_batch_size_like_op
194
- SRCS gaussian_random_batch_size_like_op.cc
195
- DEPS batch_size_like gaussian_random_op )
184
+ op_library (fill_constant_batch_size_like_op DEPS batch_size_like )
185
+ op_library (uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op )
186
+ op_library (gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op )
196
187
197
188
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
198
189
op_library (save_op DEPS lod_tensor )
0 commit comments