Skip to content

Commit ffa22a5

Browse files
committed
fix scaling param type
1 parent e870947 commit ffa22a5

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include <cfloat>
1919
#include "paddle/fluid/operators/math/math_function.h"
2020
#include "paddle/fluid/platform/cudnn_helper.h"
21+
#include "paddle/fluid/platform/float16.h"
2122

2223
namespace paddle {
2324
namespace operators {
@@ -27,7 +28,7 @@ using DataLayout = framework::DataLayout;
2728
template <typename T>
2829
using CudnnDataType = platform::CudnnDataType<T>;
2930
template <typename T>
30-
using bn_param_type = CudnnDataType<T>::bn_param_type;
31+
using ScalingParamType = typename CudnnDataType<T>::ScalingParamType;
3132

3233
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
3334
int *N, int *C, int *H, int *W, int *D) {
@@ -121,15 +122,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
121122

122123
// alloc memory
123124
y->mutable_data<T>(ctx.GetPlace());
124-
mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
125-
variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
126-
saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
127-
saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());
125+
mean_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
126+
variance_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
127+
saved_mean->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
128+
saved_variance->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
128129

129130
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
130-
math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
131-
functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
132-
functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));
131+
math::SetConstant<platform::CUDADeviceContext, ScalingParamType<T>> functor;
132+
functor(dev_ctx, saved_mean, static_cast<ScalingParamType<T>>(0));
133+
functor(dev_ctx, saved_variance, static_cast<ScalingParamType<T>>(0));
133134

134135
auto handle = dev_ctx.cudnn_handle();
135136

@@ -150,10 +151,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
150151
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
151152
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
152153
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
153-
bn_param_desc_, scale->template data<bn_param_type<T>>(),
154-
bias->template data<bn_param_type<T>>(),
155-
est_mean->template data<bn_param_type<T>>(),
156-
est_var->template data<bn_param_type<T>>(), epsilon));
154+
bn_param_desc_, scale->template data<ScalingParamType<T>>(),
155+
bias->template data<ScalingParamType<T>>(),
156+
est_mean->template data<ScalingParamType<T>>(),
157+
est_var->template data<ScalingParamType<T>>(), epsilon));
157158
} else {
158159
// Run training mode.
159160
// obtain running mean and running inv var, and see if we need to
@@ -164,13 +165,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
164165
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
165166
data_desc_, x->template data<T>(), data_desc_,
166167
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
167-
scale->template data<bn_param_type<T>>(),
168-
bias->template data<bn_param_type<T>>(), this_factor,
169-
mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
170-
variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
171-
epsilon,
172-
saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
173-
saved_variance->template mutable_data<bn_param_type<T>>(
168+
scale->template data<ScalingParamType<T>>(),
169+
bias->template data<ScalingParamType<T>>(), this_factor,
170+
mean_out->template mutable_data<ScalingParamType<T>>(ctx.GetPlace()),
171+
variance_out->template mutable_data<ScalingParamType<T>>(
172+
ctx.GetPlace()),
173+
epsilon, saved_mean->template mutable_data<ScalingParamType<T>>(
174+
ctx.GetPlace()),
175+
saved_variance->template mutable_data<ScalingParamType<T>>(
174176
ctx.GetPlace())));
175177
}
176178

paddle/fluid/platform/cudnn_helper.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,6 @@ template <>
8585
class CudnnDataType<float16> {
8686
public:
8787
static const cudnnDataType_t type = CUDNN_DATA_HALF;
88-
// cudnn batch norm requires that Scale, Bias, Mean, and Variance
89-
// to be FLOAT tensors when the input x is HALF tensor
90-
static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
9188
// The scaling param type is float for HALF and FLOAT tensors
9289
typedef const float ScalingParamType;
9390
static ScalingParamType* kOne() {
@@ -104,7 +101,6 @@ template <>
104101
class CudnnDataType<float> {
105102
public:
106103
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
107-
static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
108104
typedef const float ScalingParamType;
109105
static ScalingParamType* kOne() {
110106
static ScalingParamType v = 1.0;
@@ -120,7 +116,6 @@ template <>
120116
class CudnnDataType<double> {
121117
public:
122118
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
123-
static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE;
124119
typedef const double ScalingParamType;
125120
static ScalingParamType* kOne() {
126121
static ScalingParamType v = 1.0;

0 commit comments

Comments (0)