Skip to content

Commit 4a29292

Browse files
authored
Merge pull request #14752 from phlrain/fix_cudnn_5
fix cudnn 5 support; test=release/1.2
2 parents 08f927d + 22d1d5c commit 4a29292

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

paddle/fluid/operators/cudnn_lstm_op.cu.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,19 @@ struct CudnnRNNCache {
177177
seed_));
178178

179179
CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
180+
181+
#if CUDNN_VERSION >= 6000
180182
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
181183
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
182184
CUDNN_LINEAR_INPUT,
183185
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
184186
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
187+
#else
188+
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor(
189+
rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
190+
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
191+
CUDNN_DATA_FLOAT));
192+
#endif
185193

186194
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
187195
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));

paddle/fluid/platform/dynload/cudnn.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
125125
__macro(cudnnRNNBackwardWeights); \
126126
__macro(cudnnRNNForwardInference); \
127127
__macro(cudnnDestroyDropoutDescriptor); \
128-
__macro(cudnnDestroyRNNDescriptor); \
129-
__macro(cudnnSetRNNDescriptor_v6);
128+
__macro(cudnnDestroyRNNDescriptor);
130129

131130
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
132131

@@ -165,6 +164,12 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
165164
CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
166165
#endif
167166

167+
// APIs in R6
168+
#if CUDNN_VERSION >= 6000
169+
#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) __macro(cudnnSetRNNDescriptor_v6);
170+
CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
171+
#endif
172+
168173
#if CUDNN_VERSION >= 7001
169174
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
170175
__macro(cudnnSetConvolutionGroupCount); \

0 commit comments

Comments
 (0)