
Commit ed2bc19

Merge pull request #9176 from kexinzhao/batch_norm_fp16
Add float16 support to batch norm operator
2 parents: cd07c0f + 6ec0f91

6 files changed (+244, -25 lines)


paddle/fluid/operators/batch_norm_op.cc

Lines changed: 23 additions & 0 deletions
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SavedVariance", {C});
     ctx->ShareLoD("X", "Y");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    // For float or float16 input tensor, the type of the scale, bias, mean,
+    // and var tensors should both be float.
+    auto bn_param_type = framework::proto::VarType::FP32;
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+                      "Scale input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+                      "Bias input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+                      "Mean input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
+                                         ctx.Input<Tensor>("Variance")->type()),
+                      "Variance input should be of float type");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
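Note: the GetExpectedKernelType hook above selects the kernel from the dtype of X, while Scale, Bias, Mean, and Variance must be FP32 even when X is float16. A minimal NumPy sketch of the input dtypes a caller prepares (shapes are illustrative only, mirroring the new unit test rather than any Paddle API):

import numpy as np

c = 3
x = np.random.random_sample([2, c, 4, 5]).astype(np.float16)  # X may be fp16 or fp32
scale = np.random.random_sample([c]).astype(np.float32)       # Scale must be fp32
bias = np.random.random_sample([c]).astype(np.float32)        # Bias must be fp32
mean = np.zeros([c], dtype=np.float32)                        # Mean must be fp32
variance = np.ones([c], dtype=np.float32)                     # Variance must be fp32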

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 31 additions & 19 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <cfloat>
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+    // Note: PERSISTENT not implemented for inference
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
-        bn_param_desc_, data_desc_, mode_));
+        bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
 
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
 
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(dev_ctx, saved_mean, 0);
-    functor(dev_ctx, saved_variance, 0);
+    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+        functor;
+    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
 
     auto handle = dev_ctx.cudnn_handle();
 
@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
-          est_mean->template data<T>(), est_var->template data<T>(), epsilon));
+          bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(),
+          est_mean->template data<BatchNormParamType<T>>(),
+          est_var->template data<BatchNormParamType<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
          data_desc_, x->template data<T>(), data_desc_,
           y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<T>(), bias->template data<T>(), this_factor,
-          mean_out->template mutable_data<T>(ctx.GetPlace()),
-          variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
-          saved_mean->template mutable_data<T>(ctx.GetPlace()),
-          saved_variance->template mutable_data<T>(ctx.GetPlace())));
+          scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(), this_factor,
+          mean_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          variance_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+                       ctx.GetPlace()),
+          saved_variance->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace())));
     }
 
     // clean when exit.
@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
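Note: with this change batch_norm is registered on CUDA for both float and float16, while batch_norm_grad stays float-only, so float16 covers the forward/inference path. Because BatchNormParamType<float16> is float, the data tensor is half precision but the scale, bias, and running statistics are handled in single precision. A rough NumPy sketch of that mixed-precision inference step (assumed shapes, not Paddle code):

import numpy as np

def bn_inference_nchw(x, scale, bias, mean, var, eps=1e-5):
    # Normalize in fp32 even when x is fp16, then cast back to x's dtype.
    c = x.shape[1]
    shape = (1, c, 1, 1)
    x32 = x.astype(np.float32)
    normalized = (x32 - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    y32 = normalized * scale.reshape(shape) + bias.reshape(shape)
    return y32.astype(x.dtype)

x = np.random.random_sample([2, 3, 4, 5]).astype(np.float16)
scale = np.random.random_sample([3]).astype(np.float32)
bias = np.random.random_sample([3]).astype(np.float32)
mean = np.zeros([3], dtype=np.float32)
var = np.ones([3], dtype=np.float32)
y = bn_inference_nchw(x, scale, bias, mean, var)
assert y.dtype == np.float16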

paddle/fluid/operators/math/math_function.cc

Lines changed: 1 addition & 0 deletions
@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
   cblas_daxpy(n, alpha, x, 1, y, 1);
 }
 
+template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
 template struct SetConstant<platform::CPUDeviceContext, float>;
 template struct SetConstant<platform::CPUDeviceContext, double>;
 template struct SetConstant<platform::CPUDeviceContext, int>;

paddle/fluid/operators/math/math_function.cu

Lines changed: 1 addition & 0 deletions
@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
                               &alpha, x, 1, y, 1));
 }
 
+template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
 template struct SetConstant<platform::CUDADeviceContext, float>;
 template struct SetConstant<platform::CUDADeviceContext, double>;
 template struct SetConstant<platform::CUDADeviceContext, int>;

paddle/fluid/platform/cudnn_helper.h

Lines changed: 6 additions & 3 deletions
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
   // The scaling param type is float for HALF and FLOAT tensors
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -101,7 +102,8 @@ template <>
 class CudnnDataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -116,7 +118,8 @@ template <>
 class CudnnDataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
-  typedef const double ScalingParamType;
+  using ScalingParamType = const double;
+  using BatchNormParamType = double;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
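Note: the new BatchNormParamType member records which dtype cuDNN expects for the batch-norm parameters: float for half and float tensors, double for double tensors. A conceptual Python analogue of that mapping (an assumption for illustration, not Paddle or cuDNN API):

import numpy as np

# dtype of the batch-norm scale/bias/mean/variance for a given tensor dtype
BATCH_NORM_PARAM_DTYPE = {
    np.float16: np.float32,  # half tensors keep float parameters
    np.float32: np.float32,
    np.float64: np.float64,
}

saved_mean = np.zeros([3], dtype=BATCH_NORM_PARAM_DTYPE[np.float16])
assert saved_mean.dtype == np.float32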

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 182 additions & 3 deletions
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
     return backward_op
 
 
+def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
+    x_shape = x.shape
+    if len(x_shape) == 2:
+        if data_format == "NCHW":
+            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
+        else:
+            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+
+    if data_format == "NCHW":
+        n, c, h, w = x.shape
+        mean_tile = np.reshape(mean, (1, c, 1, 1))
+        mean_tile = np.tile(mean_tile, (n, 1, h, w))
+        var_tile = np.reshape(var, (1, c, 1, 1))
+        var_tile = np.tile(var_tile, (n, 1, h, w))
+        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
+        scale_tile = np.reshape(scale, (1, c, 1, 1))
+        scale_tile = np.tile(scale_tile, (n, 1, h, w))
+        offset_tile = np.reshape(offset, (1, c, 1, 1))
+        offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
+        y = normalized * scale_tile + offset_tile
+    elif data_format == "NHWC":
+        normalized = (x - mean) / np.sqrt(var + epsilon)
+        y = normalized * scale + offset
+    else:
+        raise ValueError("Unknown data order.")
+
+    if len(x_shape) == 2:
+        y = np.reshape(y, x_shape)
+    return y
+
+
 def _reference_training(x, scale, offset, epsilon, data_format):
     x_shape = x.shape
     if len(x_shape) == 2:
@@ -155,11 +186,159 @@ def __set_tensor__(name, data=None):
             __set_tensor__(output, data)
 
 
-class TestBatchNormOp(OpTest):
+class TestBatchNormOpInference(OpTest):
+    def setUp(self):
+        self.dtype = np.float32
+
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def test_python(self):
+    def check_with_place(self, place, data_layout, dtype, shape):
+        epsilon = 0.00001
+        if len(shape) == 2:
+            x_shape = shape
+            c = x_shape[1]
+        else:
+            n, h, w, c = shape[0], shape[1], shape[2], shape[3]
+            if data_layout == "NHWC":
+                x_shape = [n, h, w, c]
+            elif data_layout == "NCHW":
+                x_shape = [n, c, h, w]
+            else:
+                raise ValueError("Unknown data layout.")
+        scale_shape = [c]
+
+        x_val = np.random.random_sample(x_shape).astype(dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
+
+        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+                                   epsilon, data_layout).astype(dtype)
+
+        scope = core.Scope()
+
+        # create input
+        x_tensor = create_or_get_tensor(scope, "x_val",
+                                        OpTest.np_dtype_to_fluid_dtype(x_val),
+                                        place)
+        scale_tensor = create_or_get_tensor(
+            scope, "scale_val",
+            OpTest.np_dtype_to_fluid_dtype(scale_val), place)
+        bias_tensor = create_or_get_tensor(
+            scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
+        mean_tensor = create_or_get_tensor(scope, "mean",
+                                           OpTest.np_dtype_to_fluid_dtype(mean),
+                                           place)
+        variance_tensor = create_or_get_tensor(
+            scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)
+
+        # create output
+        y_tensor = create_or_get_tensor(scope, "y_out", None, place)
+        saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
+                                                 place)
+        saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
+                                                     None, place)
+        mean_out_tensor = mean_tensor
+        variance_out_tensor = variance_tensor
+
+        batch_norm_op = Operator(
+            "batch_norm",
+            # inputs
+            X="x_val",
+            Scale="scale_val",
+            Bias="bias_val",
+            Mean="mean",
+            Variance="variance",
+            # outputs
+            Y="y_out",
+            MeanOut="mean",
+            VarianceOut="variance",
+            SavedMean="saved_mean",
+            SavedVariance="saved_variance",
+            # attrs
+            is_test=True,
+            data_layout=data_layout,
+            epsilon=epsilon)
+
+        batch_norm_op.run(scope, place)
+
+        # check inference result
+        self.__assert_close(
+            y_tensor,
+            y_out,
+            "inference output are different at " + str(place) + ", " +
+            data_layout + ", " + str(np.dtype(dtype)) +
+            str(np.array(y_tensor)) + str(y_out),
+            atol=1e-3)
+
+    def test_check_output(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+            places.append(core.CUDAPlace(0))
+
+        for place in places:
+            for data_format in ["NCHW", "NHWC"]:
+                self.check_with_place(place, data_format, self.dtype,
+                                      [2, 3, 4, 5])
+                self.check_with_place(place, data_format, self.dtype, [2, 3])
+
+
+class TestFP16BatchNormOpInference(TestBatchNormOpInference):
+    def setUp(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        places = []
+        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                places.append(place)
+
+        for place in places:
+            for data_format in ["NCHW", "NHWC"]:
+                self.check_with_place(place, data_format, self.dtype,
+                                      [2, 3, 4, 5])
+                self.check_with_place(place, data_format, self.dtype, [2, 3])
+
+
+class TestBatchNormOpTraining(OpTest):
+    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
+        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
+
+    def test_python_testing(self):
+        data_format = "NHWC"
+        epsilon = 0.00001
+
+        n, h, w, c = 2, 3, 4, 5
+        x_shape = [n, h, w, c]
+        scale_shape = [c]
+
+        x_val = np.random.random_sample(x_shape).astype(np.float32)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
+
+        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+                                   epsilon, "NHWC")
+
+        # running N, C, H, W case
+        # should produce the same results
+        x_shape2 = [n, c, h, w]
+        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
+        y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
+                                    epsilon, "NCHW")
+
+        # transfer (N, C, H, W) back to (N, H, W, C)
+        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
+        self.__assert_close(y_out, y_out2_trans, "inference output")
+        print 'python: NHWC, NCHW, inference checking passed'
+
+    def test_python_training(self):
         data_format = "NHWC"
         epsilon = 0.00001
         momentum = 0.9
@@ -197,7 +376,7 @@ def test_python(self):
 
         # transfer (N, C, H, W) back to (N, H, W, C)
         y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
-        self.__assert_close(y_out, y_out2_trans, "batch variance")
+        self.__assert_close(y_out, y_out2_trans, "batch output")
         print 'python: NHWC, NCHW, forward checking passed'
 
         # test backward now