@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using bn_param_type = typename CudnnDataType<T>::bn_param_type;
 
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
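The new alias forwards a nested bn_param_type from platform::CudnnDataType<T>; the trait member itself is defined outside this hunk. A minimal sketch of what it might look like, assuming the intent is that half-precision data keeps its batch-norm scale/bias/mean/variance in single precision (the specializations below are illustrative, not code from this patch):

namespace platform {

// Sketch only: primary CudnnDataType template is declared elsewhere in
// cudnn_helper.h.  cuDNN requires FP16 inputs to carry their batch-norm
// parameter tensors in FP32, so the trait maps float16 -> float.
template <>
class CudnnDataType<float16> {
 public:
  static const cudnnDataType_t type = CUDNN_DATA_HALF;
  using bn_param_type = float;  // assumed member: parameters stay in FP32
};

template <>
class CudnnDataType<float> {
 public:
  static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
  using bn_param_type = float;  // same precision as the data
};

}  // namespace platform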
@@ -104,8 +106,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+    // Note: PERSISTENT not implemented for inference
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
-        bn_param_desc_, data_desc_, mode_));
+        bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
 
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
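The added note and the is_test ternary reflect that cuDNN's persistent batch-norm mode is only accepted by the training entry points; cudnnBatchNormalizationForwardInference does not support it, so the parameter descriptor is derived with plain CUDNN_BATCHNORM_SPATIAL whenever the kernel runs in test mode. A sketch of how mode_ is typically chosen earlier in the kernel (an assumption for illustration; the actual selection is not part of this hunk):

// Sketch (assumption): pick the fastest mode the installed cuDNN supports for
// training; the inference path above falls back to CUDNN_BATCHNORM_SPATIAL
// regardless, because the persistent mode is training-only.
cudnnBatchNormMode_t mode_;
#if CUDNN_VERSION >= 7000
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;  // cuDNN 7+, training calls only
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif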
@@ -118,15 +121,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
 
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(dev_ctx, saved_mean, static_cast<T>(0));
-    functor(dev_ctx, saved_variance, static_cast<T>(0));
+    math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
+    functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));
 
     auto handle = dev_ctx.cudnn_handle();
 
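Switching the allocations and the zero-fill to bn_param_type<T> means the running and saved statistics no longer share the input precision. For illustration, with T = platform::float16 and the assumed bn_param_type<float16> = float, the hunk above instantiates to:

// Equivalent expansion for FP16 input (illustrative only):
mean_out->mutable_data<float>(ctx.GetPlace());      // running mean in FP32
variance_out->mutable_data<float>(ctx.GetPlace());  // running variance in FP32
saved_mean->mutable_data<float>(ctx.GetPlace());
saved_variance->mutable_data<float>(ctx.GetPlace());

auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, float> functor;
functor(dev_ctx, saved_mean, static_cast<float>(0));      // zero-fill in FP32
functor(dev_ctx, saved_variance, static_cast<float>(0));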
@@ -147,8 +150,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
-          est_mean->template data<T>(), est_var->template data<T>(), epsilon));
+          bn_param_desc_, scale->template data<bn_param_type<T>>(),
+          bias->template data<bn_param_type<T>>(),
+          est_mean->template data<bn_param_type<T>>(),
+          est_var->template data<bn_param_type<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -159,11 +164,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
           data_desc_, x->template data<T>(), data_desc_,
           y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<T>(), bias->template data<T>(), this_factor,
-          mean_out->template mutable_data<T>(ctx.GetPlace()),
-          variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
-          saved_mean->template mutable_data<T>(ctx.GetPlace()),
-          saved_variance->template mutable_data<T>(ctx.GetPlace())));
+          scale->template data<bn_param_type<T>>(),
+          bias->template data<bn_param_type<T>>(), this_factor,
+          mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+          variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+          epsilon,
+          saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+          saved_variance->template mutable_data<bn_param_type<T>>(
+              ctx.GetPlace())));
     }
 
     // clean when exit.
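Both cuDNN calls above take the scale/bias/mean/variance pointers as void *, so the change to data<bn_param_type<T>>() matters because the pointed-to buffers must match the data type that cudnnDeriveBNTensorDescriptor recorded in bn_param_desc_. For T = platform::float16, with the assumed bn_param_type<float16> = float and an assumed float16 specialization of CudnnDataType providing kOne()/kZero(), the training call expands roughly to the annotated sketch below; the inference call pairs its pointers with bn_param_desc_ the same way.

// Illustrative expansion for FP16 input (annotations are mine, not from the
// patch): data tensors stay in half precision, every tensor described by
// bn_param_desc_ is single precision.
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
    handle, mode_,
    CudnnDataType<platform::float16>::kOne(),
    CudnnDataType<platform::float16>::kZero(),
    data_desc_, x->template data<platform::float16>(),        // FP16 input
    data_desc_, y->template mutable_data<platform::float16>(ctx.GetPlace()),
    bn_param_desc_,                                            // FP32, 1xCx1x1
    scale->template data<float>(), bias->template data<float>(),
    this_factor,                                               // running-average factor
    mean_out->template mutable_data<float>(ctx.GetPlace()),
    variance_out->template mutable_data<float>(ctx.GetPlace()),
    epsilon,
    saved_mean->template mutable_data<float>(ctx.GetPlace()),
    saved_variance->template mutable_data<float>(ctx.GetPlace())));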