@@ -28,7 +28,7 @@ using DataLayout = framework::DataLayout;
28
28
template <typename T>
29
29
using CudnnDataType = platform::CudnnDataType<T>;
30
30
template <typename T>
31
- using ScalingParamType = typename CudnnDataType<T>::ScalingParamType ;
31
+ using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType ;
32
32
33
33
void ExtractNCWHD (const framework::DDim &dims, const DataLayout &data_layout,
34
34
int *N, int *C, int *H, int *W, int *D) {
@@ -122,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
122
122
123
123
// alloc memory
124
124
y->mutable_data <T>(ctx.GetPlace ());
125
- mean_out->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
126
- variance_out->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
127
- saved_mean->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
128
- saved_variance->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
125
+ mean_out->mutable_data <BatchNormParamType <T>>(ctx.GetPlace ());
126
+ variance_out->mutable_data <BatchNormParamType <T>>(ctx.GetPlace ());
127
+ saved_mean->mutable_data <BatchNormParamType <T>>(ctx.GetPlace ());
128
+ saved_variance->mutable_data <BatchNormParamType <T>>(ctx.GetPlace ());
129
129
130
130
auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
131
- math::SetConstant<platform::CUDADeviceContext, ScalingParamType<T>> functor;
132
- functor (dev_ctx, saved_mean, static_cast <ScalingParamType<T>>(0 ));
133
- functor (dev_ctx, saved_variance, static_cast <ScalingParamType<T>>(0 ));
131
+ math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
132
+ functor;
133
+ functor (dev_ctx, saved_mean, static_cast <BatchNormParamType<T>>(0 ));
134
+ functor (dev_ctx, saved_variance, static_cast <BatchNormParamType<T>>(0 ));
134
135
135
136
auto handle = dev_ctx.cudnn_handle ();
136
137
@@ -151,10 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
151
152
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne (),
152
153
CudnnDataType<T>::kZero (), data_desc_, x->template data <T>(),
153
154
data_desc_, y->template mutable_data <T>(ctx.GetPlace ()),
154
- bn_param_desc_, scale->template data <ScalingParamType <T>>(),
155
- bias->template data <ScalingParamType <T>>(),
156
- est_mean->template data <ScalingParamType <T>>(),
157
- est_var->template data <ScalingParamType <T>>(), epsilon));
155
+ bn_param_desc_, scale->template data <BatchNormParamType <T>>(),
156
+ bias->template data <BatchNormParamType <T>>(),
157
+ est_mean->template data <BatchNormParamType <T>>(),
158
+ est_var->template data <BatchNormParamType <T>>(), epsilon));
158
159
} else {
159
160
// Run training mode.
160
161
// obtain running mean and running inv var, and see if we need to
@@ -165,14 +166,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
165
166
handle, mode_, CudnnDataType<T>::kOne (), CudnnDataType<T>::kZero (),
166
167
data_desc_, x->template data <T>(), data_desc_,
167
168
y->template mutable_data <T>(ctx.GetPlace ()), bn_param_desc_,
168
- scale->template data <ScalingParamType<T>>(),
169
- bias->template data <ScalingParamType<T>>(), this_factor,
170
- mean_out->template mutable_data <ScalingParamType<T>>(ctx.GetPlace ()),
171
- variance_out->template mutable_data <ScalingParamType<T>>(
169
+ scale->template data <BatchNormParamType<T>>(),
170
+ bias->template data <BatchNormParamType<T>>(), this_factor,
171
+ mean_out->template mutable_data <BatchNormParamType<T>>(
172
172
ctx.GetPlace ()),
173
- epsilon, saved_mean->template mutable_data <ScalingParamType<T>>(
173
+ variance_out->template mutable_data <BatchNormParamType<T>>(
174
+ ctx.GetPlace ()),
175
+ epsilon, saved_mean->template mutable_data <BatchNormParamType<T>>(
174
176
ctx.GetPlace ()),
175
- saved_variance->template mutable_data <ScalingParamType <T>>(
177
+ saved_variance->template mutable_data <BatchNormParamType <T>>(
176
178
ctx.GetPlace ())));
177
179
}
178
180
0 commit comments