@@ -18,6 +18,7 @@ limitations under the License. */
18
18
#include < cfloat>
19
19
#include " paddle/fluid/operators/math/math_function.h"
20
20
#include " paddle/fluid/platform/cudnn_helper.h"
21
+ #include " paddle/fluid/platform/float16.h"
21
22
22
23
namespace paddle {
23
24
namespace operators {
@@ -27,7 +28,7 @@ using DataLayout = framework::DataLayout;
27
28
template <typename T>
28
29
using CudnnDataType = platform::CudnnDataType<T>;
29
30
template <typename T>
30
- using bn_param_type = CudnnDataType<T>::bn_param_type ;
31
+ using ScalingParamType = typename CudnnDataType<T>::ScalingParamType ;
31
32
32
33
void ExtractNCWHD (const framework::DDim &dims, const DataLayout &data_layout,
33
34
int *N, int *C, int *H, int *W, int *D) {
@@ -121,15 +122,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
121
122
122
123
// alloc memory
123
124
y->mutable_data <T>(ctx.GetPlace ());
124
- mean_out->mutable_data <bn_param_type <T>>(ctx.GetPlace ());
125
- variance_out->mutable_data <bn_param_type <T>>(ctx.GetPlace ());
126
- saved_mean->mutable_data <bn_param_type <T>>(ctx.GetPlace ());
127
- saved_variance->mutable_data <bn_param_type <T>>(ctx.GetPlace ());
125
+ mean_out->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
126
+ variance_out->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
127
+ saved_mean->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
128
+ saved_variance->mutable_data <ScalingParamType <T>>(ctx.GetPlace ());
128
129
129
130
auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
130
- math::SetConstant<platform::CUDADeviceContext, bn_param_type <T>> functor;
131
- functor (dev_ctx, saved_mean, static_cast <bn_param_type <T>>(0 ));
132
- functor (dev_ctx, saved_variance, static_cast <bn_param_type <T>>(0 ));
131
+ math::SetConstant<platform::CUDADeviceContext, ScalingParamType <T>> functor;
132
+ functor (dev_ctx, saved_mean, static_cast <ScalingParamType <T>>(0 ));
133
+ functor (dev_ctx, saved_variance, static_cast <ScalingParamType <T>>(0 ));
133
134
134
135
auto handle = dev_ctx.cudnn_handle ();
135
136
@@ -150,10 +151,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
150
151
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne (),
151
152
CudnnDataType<T>::kZero (), data_desc_, x->template data <T>(),
152
153
data_desc_, y->template mutable_data <T>(ctx.GetPlace ()),
153
- bn_param_desc_, scale->template data <bn_param_type <T>>(),
154
- bias->template data <bn_param_type <T>>(),
155
- est_mean->template data <bn_param_type <T>>(),
156
- est_var->template data <bn_param_type <T>>(), epsilon));
154
+ bn_param_desc_, scale->template data <ScalingParamType <T>>(),
155
+ bias->template data <ScalingParamType <T>>(),
156
+ est_mean->template data <ScalingParamType <T>>(),
157
+ est_var->template data <ScalingParamType <T>>(), epsilon));
157
158
} else {
158
159
// Run training mode.
159
160
// obtain running mean and running inv var, and see if we need to
@@ -164,13 +165,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
164
165
handle, mode_, CudnnDataType<T>::kOne (), CudnnDataType<T>::kZero (),
165
166
data_desc_, x->template data <T>(), data_desc_,
166
167
y->template mutable_data <T>(ctx.GetPlace ()), bn_param_desc_,
167
- scale->template data <bn_param_type<T>>(),
168
- bias->template data <bn_param_type<T>>(), this_factor,
169
- mean_out->template mutable_data <bn_param_type<T>>(ctx.GetPlace ()),
170
- variance_out->template mutable_data <bn_param_type<T>>(ctx.GetPlace ()),
171
- epsilon,
172
- saved_mean->template mutable_data <bn_param_type<T>>(ctx.GetPlace ()),
173
- saved_variance->template mutable_data <bn_param_type<T>>(
168
+ scale->template data <ScalingParamType<T>>(),
169
+ bias->template data <ScalingParamType<T>>(), this_factor,
170
+ mean_out->template mutable_data <ScalingParamType<T>>(ctx.GetPlace ()),
171
+ variance_out->template mutable_data <ScalingParamType<T>>(
172
+ ctx.GetPlace ()),
173
+ epsilon, saved_mean->template mutable_data <ScalingParamType<T>>(
174
+ ctx.GetPlace ()),
175
+ saved_variance->template mutable_data <ScalingParamType<T>>(
174
176
ctx.GetPlace ())));
175
177
}
176
178
0 commit comments