
Commit e870947

fix batch norm fp16 param type
1 parent 3233b2b commit e870947

File tree

4 files changed, with 69 additions and 28 deletions


paddle/fluid/operators/batch_norm_op.cc

Lines changed: 23 additions & 0 deletions
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SavedVariance", {C});
     ctx->ShareLoD("X", "Y");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    // For float or float16 input tensor, the type of the scale, bias, mean,
+    // and var tensors should both be float.
+    auto bn_param_type = framework::proto::VarType::FP32;
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+                      "Scale input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+                      "Bias input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+                      "Mean input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
+                                         ctx.Input<Tensor>("Variance")->type()),
+                      "Variance input should be of float type");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
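
The effect of the new GetExpectedKernelType is that Scale, Bias, Mean, and Variance must be float32 tensors whether X is float32 or float16, while the kernel is still selected from X's type. A minimal standalone sketch of that contract, with a plain enum as a hypothetical stand-in for framework::proto::VarType (none of these names are Paddle APIs):

#include <cassert>

// Hypothetical stand-in for framework::proto::VarType.
enum class DType { FP16, FP32, FP64 };

// Mirrors the PADDLE_ENFORCE_EQ checks above: regardless of whether X is
// FP16 or FP32, the batch-norm parameter tensors must be FP32.
void CheckBatchNormParamTypes(DType scale, DType bias, DType mean,
                              DType variance) {
  const DType bn_param_type = DType::FP32;
  assert(scale == bn_param_type && "Scale input should be of float type");
  assert(bias == bn_param_type && "Bias input should be of float type");
  assert(mean == bn_param_type && "Mean input should be of float type");
  assert(variance == bn_param_type && "Variance input should be of float type");
}

int main() {
  // FP32 parameter tensors are accepted by the checks above.
  CheckBatchNormParamTypes(DType::FP32, DType::FP32, DType::FP32, DType::FP32);
  return 0;
}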

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 23 additions & 15 deletions
@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using bn_param_type = CudnnDataType<T>::bn_param_type;
 
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -104,8 +106,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+    // Note: PERSISTENT not implemented for inference
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
-        bn_param_desc_, data_desc_, mode_));
+        bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
 
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -118,15 +121,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
 
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(dev_ctx, saved_mean, static_cast<T>(0));
-    functor(dev_ctx, saved_variance, static_cast<T>(0));
+    math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
+    functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));
 
     auto handle = dev_ctx.cudnn_handle();
 
@@ -147,8 +150,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
-          est_mean->template data<T>(), est_var->template data<T>(), epsilon));
+          bn_param_desc_, scale->template data<bn_param_type<T>>(),
+          bias->template data<bn_param_type<T>>(),
+          est_mean->template data<bn_param_type<T>>(),
+          est_var->template data<bn_param_type<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -159,11 +164,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
           data_desc_, x->template data<T>(), data_desc_,
           y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<T>(), bias->template data<T>(), this_factor,
-          mean_out->template mutable_data<T>(ctx.GetPlace()),
-          variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
-          saved_mean->template mutable_data<T>(ctx.GetPlace()),
-          saved_variance->template mutable_data<T>(ctx.GetPlace())));
+          scale->template data<bn_param_type<T>>(),
+          bias->template data<bn_param_type<T>>(), this_factor,
+          mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+          variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+          epsilon,
+          saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+          saved_variance->template mutable_data<bn_param_type<T>>(
+              ctx.GetPlace())));
     }
 
     // clean when exit.
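
In the inference branch, the kernel now passes CUDNN_BATCHNORM_SPATIAL to cudnnDeriveBNTensorDescriptor because the PERSISTENT variant is not implemented for inference. A small standalone sketch of that call (assuming a cuDNN installation to compile and link against; not code from this commit): deriving the Scale/Bias/Mean/Variance descriptor from an FP16 NCHW input descriptor is expected to yield a CUDNN_DATA_FLOAT descriptor of shape 1xCx1x1.

#include <cstdio>
#include <cudnn.h>

int main() {
  cudnnTensorDescriptor_t x_desc, bn_param_desc;
  cudnnCreateTensorDescriptor(&x_desc);
  cudnnCreateTensorDescriptor(&bn_param_desc);

  // FP16 input in NCHW layout, e.g. the [2, 3, 4, 5] shape used by the test.
  cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, 2, 3,
                             4, 5);
  // Inference path: SPATIAL mode, as in the is_test branch of the kernel.
  cudnnDeriveBNTensorDescriptor(bn_param_desc, x_desc, CUDNN_BATCHNORM_SPATIAL);

  cudnnDataType_t dtype;
  int n, c, h, w, ns, cs, hs, ws;
  cudnnGetTensor4dDescriptor(bn_param_desc, &dtype, &n, &c, &h, &w, &ns, &cs,
                             &hs, &ws);
  // Expected: dtype == CUDNN_DATA_FLOAT and shape 1x3x1x1, i.e. the BN
  // parameter tensors are FP32 even though x is FP16.
  std::printf("bn param dtype=%d, shape=%dx%dx%dx%d\n",
              static_cast<int>(dtype), n, c, h, w);

  cudnnDestroyTensorDescriptor(bn_param_desc);
  cudnnDestroyTensorDescriptor(x_desc);
  return 0;
}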

paddle/fluid/platform/cudnn_helper.h

Lines changed: 5 additions & 0 deletions
@@ -85,6 +85,9 @@ template <>
 class CudnnDataType<float16> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
+  // cudnn batch norm requires that Scale, Bias, Mean, and Variance
+  // to be FLOAT tensors when the input x is HALF tensor
+  static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
   // The scaling param type is float for HALF and FLOAT tensors
   typedef const float ScalingParamType;
   static ScalingParamType* kOne() {
@@ -101,6 +104,7 @@ template <>
 class CudnnDataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
+  static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
   typedef const float ScalingParamType;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
@@ -116,6 +120,7 @@ template <>
 class CudnnDataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
+  static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE;
   typedef const double ScalingParamType;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
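
The bn_param_type constants added here record, per element type, the cuDNN data type the batch-norm parameter tensors must use: float for half and float inputs, double for double. As a rough compile-time sketch of the same rule (hypothetical names, not part of the commit, using CUDA's __half from cuda_fp16.h as a stand-in for platform::float16):

#include <cuda_fp16.h>
#include <type_traits>

// Map a tensor element type to the element type of its batch-norm
// Scale/Bias/Mean/Variance tensors.
template <typename T>
struct BatchNormParamType {
  using type = T;  // float -> float, double -> double
};

template <>
struct BatchNormParamType<__half> {
  using type = float;  // FP16 activations still use FP32 BN parameters
};

static_assert(std::is_same<BatchNormParamType<__half>::type, float>::value,
              "FP16 inputs must use FP32 batch-norm parameters");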

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 18 additions & 13 deletions
@@ -193,7 +193,7 @@ def setUp(self):
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def check_with_place(place, data_layout, dtype, shape):
+    def check_with_place(self, place, data_layout, dtype, shape):
         epsilon = 0.00001
         if len(shape) == 2:
             x_shape = shape
@@ -209,11 +209,11 @@ def check_with_place(place, data_layout, dtype, shape):
             scale_shape = [c]
 
         x_val = np.random.random_sample(x_shape).astype(dtype)
-        scale_val = np.random.random_sample(scale_shape).astype(dtype)
-        bias_val = np.random.random_sample(scale_shape).astype(dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
 
-        mean = np.zeros(scale_shape).astype(dtype)
-        variance = np.ones(scale_shape).astype(dtype)
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
 
         y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                    epsilon, data_layout).astype(dtype)
@@ -266,9 +266,13 @@ def check_with_place(place, data_layout, dtype, shape):
         batch_norm_op.run(scope, place)
 
         # check inference result
-        self.__assert_close(y_tensor, y_out,
-                            "inference output are different at " + str(place) +
-                            ", " + data_layout + ", " + str(np.dtype(dtype)))
+        self.__assert_close(
+            y_tensor,
+            y_out,
+            "inference output are different at " + str(place) + ", " +
+            data_layout + ", " + str(np.dtype(dtype)) +
+            str(np.array(y_tensor)) + str(y_out),
+            atol=2e-2)
 
     def test_check_output(self):
         places = [core.CPUPlace()]
@@ -277,8 +281,9 @@ def test_check_output(self):
 
         for place in places:
             for data_format in ["NCHW", "NHWC"]:
-                check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
-                check_with_place(place, data_format, self.dtype, [2, 3])
+                self.check_with_place(place, data_format, self.dtype,
+                                      [2, 3, 4, 5])
+                self.check_with_place(place, data_format, self.dtype, [2, 3])
 
 
 class TestFP16BatchNormOpInference(TestBatchNormOpInference):
@@ -294,9 +299,9 @@ def test_check_output(self):
 
         for place in places:
            for data_format in ["NCHW", "NHWC"]:
-                check_output_with_place(place, data_format, self.dtype,
-                                        [2, 3, 4, 5])
-                check_output_with_place(place, data_format, self.dtype, [2, 3])
+                self.check_with_place(place, data_format, self.dtype,
+                                      [2, 3, 4, 5])
+                self.check_with_place(place, data_format, self.dtype, [2, 3])
 
 
 class TestBatchNormOpTraining(OpTest):
