Skip to content

Commit 56b723c

Browse files
authored
Cudnn batch norm op (#5067)
* init cudnn batch norm op * rename batch_norm_cudnn_op.cc batch_norm_op.cu * correct name style * add ExtractNCWHD, simplify code * fix ExtractNCWHD * use CUDNN_ENFORCE instead of PADDLE_ENFORCE
1 parent 629cbda commit 56b723c

File tree

3 files changed

+322
-0
lines changed

3 files changed

+322
-0
lines changed

paddle/operators/batch_norm_op.cu

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/batch_norm_op.h"
16+
17+
#include <cfloat>
18+
#include "paddle/operators/math/math_function.h"
19+
#include "paddle/platform/cudnn_helper.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using Tensor = framework::Tensor;
25+
template <typename T>
26+
using CudnnDataType = platform::CudnnDataType<T>;
27+
28+
// Splits the shape `dims` into batch (N), channel (C) and up to three
// spatial extents (H, W, D) according to `tensor_format`.
//
// For TensorFormat::NCHW the channel is read from index 1 and the spatial
// extents from indices 2, 3, 4. For any other format the channel is read from
// the last index and the spatial extents from indices 1, 2, 3.
// W is set to 1 when dims has 3 or fewer entries, and D is set to 1 when dims
// has 4 or fewer entries. The caller is expected to have validated the rank
// (both kernels in this file enforce 3 <= rank <= 5 before calling).
void ExtractNCWHD(const framework::DDim &dims,
                  const TensorFormat &tensor_format, int *N, int *C, int *H,
                  int *W, int *D) {
  const bool channel_first = (tensor_format == TensorFormat::NCHW);
  const auto rank = dims.size();

  *N = dims[0];
  if (channel_first) {
    *C = dims[1];
    *H = dims[2];
    *W = rank > 3 ? dims[3] : 1;
    *D = rank > 4 ? dims[4] : 1;
  } else {
    *C = dims[rank - 1];
    *H = dims[1];
    *W = rank > 3 ? dims[2] : 1;
    *D = rank > 4 ? dims[3] : 1;
  }
}
41+
42+
template <typename T>
43+
class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
44+
public:
45+
void Compute(const framework::ExecutionContext &ctx) const override {
46+
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
47+
"It must use GPUPlace.");
48+
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
49+
const float momentum = ctx.Attr<float>("momentum");
50+
const bool is_test = ctx.Attr<bool>("is_test");
51+
const std::string tensor_format_str =
52+
ctx.Attr<std::string>("tensor_format");
53+
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
54+
55+
// Get the size for each dimension.
56+
// NCHW [batch_size, in_channels, in_height, in_width]
57+
const auto *x = ctx.Input<Tensor>("X");
58+
const auto &x_dims = x->dims();
59+
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
60+
"The Input dim size should be between 3 and 5");
61+
int N, C, H, W, D;
62+
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
63+
64+
// ------------------- cudnn descriptors ---------------------
65+
cudnnTensorDescriptor_t data_desc_;
66+
cudnnTensorDescriptor_t bn_param_desc_;
67+
cudnnBatchNormMode_t mode_;
68+
69+
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
70+
CUDNN_ENFORCE(
71+
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
72+
73+
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
74+
LOG(ERROR) << "Provided epsilon is smaller than "
75+
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
76+
<< "CUDNN_BN_MIN_EPSILON instead.";
77+
}
78+
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
79+
#if CUDNN_VERSION_MIN(7, 0, 0)
80+
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
81+
#else
82+
mode_ = CUDNN_BATCHNORM_SPATIAL;
83+
#endif
84+
85+
VLOG(1) << "Setting descriptors.";
86+
std::vector<int> dims;
87+
std::vector<int> strides;
88+
if (tensor_format == TensorFormat::NCHW) {
89+
dims = {N, C, H, W, D};
90+
strides = {C * H * W * D, H * W * D, W * D, D, 1};
91+
} else {
92+
dims = {N, C, H, W, D};
93+
strides = {H * W * D * C, 1, W * D * C, D * C, C};
94+
}
95+
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
96+
data_desc_, CudnnDataType<T>::type,
97+
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
98+
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
99+
bn_param_desc_, data_desc_, mode_));
100+
101+
const auto *scale = ctx.Input<Tensor>("Scale");
102+
const auto *bias = ctx.Input<Tensor>("Bias");
103+
104+
auto *y = ctx.Output<Tensor>("Y");
105+
auto *mean_out = ctx.Output<Tensor>("MeanOut");
106+
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
107+
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
108+
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
109+
110+
// alloc memory
111+
y->mutable_data<T>(ctx.GetPlace());
112+
mean_out->mutable_data<T>(ctx.GetPlace());
113+
variance_out->mutable_data<T>(ctx.GetPlace());
114+
saved_mean->mutable_data<T>(ctx.GetPlace());
115+
saved_variance->mutable_data<T>(ctx.GetPlace());
116+
117+
math::SetConstant<platform::GPUPlace, T> functor;
118+
functor(ctx.device_context(), saved_mean, 0);
119+
functor(ctx.device_context(), saved_variance, 0);
120+
// FIXME(qiao) should not set zero self
121+
functor(ctx.device_context(), mean_out, 0);
122+
functor(ctx.device_context(), variance_out, 0);
123+
124+
auto handle = ctx.cuda_device_context().cudnn_handle();
125+
126+
// Now, depending on whether we are running test or not, we have two paths.
127+
if (is_test) {
128+
// only when test we use input to do computation.
129+
const auto *est_mean = ctx.Input<Tensor>("Mean");
130+
const auto *est_var = ctx.Input<Tensor>("Variance");
131+
// Run inference mode.
132+
PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL);
133+
PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL);
134+
PADDLE_ENFORCE_EQ(est_mean->dims()[0], C);
135+
PADDLE_ENFORCE_EQ(est_var->dims()[0], C);
136+
137+
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference(
138+
handle,
139+
// Note: PERSISTENT not implemented for inference
140+
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
141+
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
142+
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
143+
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
144+
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
145+
} else {
146+
// Run training mode.
147+
// obtain running mean and running inv var, and see if we need to
148+
// initialize them.
149+
double this_factor = 1. - momentum;
150+
151+
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
152+
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
153+
data_desc_, x->template data<T>(), data_desc_,
154+
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
155+
scale->template data<T>(), bias->template data<T>(), this_factor,
156+
mean_out->template mutable_data<T>(ctx.GetPlace()),
157+
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
158+
saved_mean->template mutable_data<T>(ctx.GetPlace()),
159+
saved_variance->template mutable_data<T>(ctx.GetPlace())));
160+
}
161+
162+
// clean when exit.
163+
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
164+
CUDNN_ENFORCE(
165+
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
166+
}
167+
};
168+
169+
template <typename T>
170+
class BatchNormGradKernel<platform::GPUPlace, T>
171+
: public framework::OpKernel<T> {
172+
public:
173+
void Compute(const framework::ExecutionContext &ctx) const override {
174+
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
175+
"It must use GPUPlace.");
176+
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
177+
const std::string tensor_format_str =
178+
ctx.Attr<std::string>("tensor_format");
179+
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
180+
const auto *x = ctx.Input<Tensor>("X");
181+
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
182+
const auto *scale = ctx.Input<Tensor>("Scale");
183+
184+
const auto &x_dims = x->dims();
185+
186+
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
187+
"The Input dim size should be between 3 and 5");
188+
int N, C, H, W, D;
189+
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
190+
191+
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
192+
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
193+
194+
// ------------------- cudnn descriptors ---------------------
195+
cudnnTensorDescriptor_t data_desc_;
196+
cudnnTensorDescriptor_t bn_param_desc_;
197+
cudnnBatchNormMode_t mode_;
198+
199+
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
200+
CUDNN_ENFORCE(
201+
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
202+
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
203+
LOG(ERROR) << "Provided epsilon is smaller than "
204+
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
205+
<< "CUDNN_BN_MIN_EPSILON instead.";
206+
}
207+
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
208+
#if CUDNN_VERSION_MIN(7, 0, 0)
209+
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
210+
#else
211+
mode_ = CUDNN_BATCHNORM_SPATIAL;
212+
#endif
213+
214+
std::vector<int> dims = {N, C, H, W, D};
215+
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
216+
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
217+
data_desc_, CudnnDataType<T>::type,
218+
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
219+
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
220+
bn_param_desc_, data_desc_, mode_));
221+
222+
// init output
223+
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
224+
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
225+
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
226+
227+
d_x->mutable_data<T>(ctx.GetPlace());
228+
d_scale->mutable_data<T>(ctx.GetPlace());
229+
d_bias->mutable_data<T>(ctx.GetPlace());
230+
231+
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
232+
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
233+
const void *saved_mean_data = saved_mean->template data<T>();
234+
const void *saved_var_data = saved_var->template data<T>();
235+
236+
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
237+
ctx.cuda_device_context().cudnn_handle(), mode_,
238+
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
239+
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
240+
x->template data<T>(), data_desc_, d_y->template data<T>(), data_desc_,
241+
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
242+
scale->template data<T>(),
243+
d_scale->template mutable_data<T>(ctx.GetPlace()),
244+
d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
245+
saved_mean_data, saved_var_data));
246+
247+
// clean when exit.
248+
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
249+
CUDNN_ENFORCE(
250+
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
251+
}
252+
};
253+
254+
} // namespace operators
255+
} // namespace paddle
256+
257+
namespace ops = paddle::operators;
258+
REGISTER_OP_GPU_KERNEL(batch_norm,
259+
ops::BatchNormKernel<paddle::platform::GPUPlace, float>);
260+
REGISTER_OP_GPU_KERNEL(
261+
batch_norm_grad,
262+
ops::BatchNormGradKernel<paddle::platform::GPUPlace, float>);

paddle/platform/cudnn_helper.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,47 @@ limitations under the License. */
2222
namespace paddle {
2323
namespace platform {
2424

25+
// Maps a cuDNN status code to the spelling of its enumerator.
// Returns "Unknown cudnn error number" for any value not listed below.
// The returned pointer refers to a string literal and must not be freed.
inline const char* cudnnGetErrorString(cudnnStatus_t status) {
// Expands to `case X: return "X"` so each enumerator is spelled only once.
#define PADDLE_CUDNN_STATUS_NAME(code) \
  case code:                           \
    return #code

  switch (status) {
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_SUCCESS);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_NOT_INITIALIZED);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_ALLOC_FAILED);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_BAD_PARAM);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_INTERNAL_ERROR);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_INVALID_VALUE);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_ARCH_MISMATCH);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_MAPPING_ERROR);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_EXECUTION_FAILED);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_NOT_SUPPORTED);
    PADDLE_CUDNN_STATUS_NAME(CUDNN_STATUS_LICENSE_ERROR);
    default:
      return "Unknown cudnn error number";
  }

#undef PADDLE_CUDNN_STATUS_NAME
}
53+
54+
// True when the cuDNN headers being compiled against are at least version
// major.minor.patch. The requested version is encoded as
// major * 1000 + minor * 100 + patch and compared against CUDNN_VERSION.
// The expansion is a constant expression, so it can be used in #if.
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= (((major) * 1000) + ((minor) * 100) + (patch)))
56+
57+
// Evaluates the cuDNN call `condition` exactly once and throws through
// PADDLE_THROW when the result is not CUDNN_STATUS_SUCCESS.
//
// The status name from cudnnGetErrorString is logged at VLOG(1) and is also
// included in the thrown message, so the cause is visible without raising the
// log verbosity. The argument is parenthesized and the temporary uses a
// macro-specific name so that an expression mentioning a caller variable
// called `status` is not shadowed.
#define CUDNN_ENFORCE(condition)                                           \
  do {                                                                     \
    const cudnnStatus_t paddle_cudnn_status_ = (condition);                \
    if (paddle_cudnn_status_ != CUDNN_STATUS_SUCCESS) {                    \
      VLOG(1) << ::paddle::platform::cudnnGetErrorString(                  \
          paddle_cudnn_status_);                                           \
      PADDLE_THROW("cuDNN call failed: %s",                                \
                   ::paddle::platform::cudnnGetErrorString(                \
                       paddle_cudnn_status_));                             \
    }                                                                      \
  } while (false)
65+
2566
enum class DataLayout {
2667
kNHWC,
2768
kNCHW,
@@ -40,12 +81,30 @@ template <>
4081
class CudnnDataType<float> {
4182
public:
4283
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
84+
typedef const float ScalingParamType;
85+
static ScalingParamType* kOne() {
86+
static ScalingParamType v = 1.0;
87+
return &v;
88+
}
89+
static ScalingParamType* kZero() {
90+
static ScalingParamType v = 0.0;
91+
return &v;
92+
}
4393
};
4494

4595
template <>
4696
class CudnnDataType<double> {
4797
public:
4898
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
99+
typedef const double ScalingParamType;
100+
static ScalingParamType* kOne() {
101+
static ScalingParamType v = 1.0;
102+
return &v;
103+
}
104+
static ScalingParamType* kZero() {
105+
static ScalingParamType v = 0.0;
106+
return &v;
107+
}
49108
};
50109

51110
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {

paddle/platform/dynload/cudnn.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ extern void* cudnn_dso_handle;
8383
__macro(cudnnDestroyConvolutionDescriptor); \
8484
__macro(cudnnSetConvolutionNdDescriptor); \
8585
__macro(cudnnGetConvolutionNdDescriptor); \
86+
__macro(cudnnDeriveBNTensorDescriptor); \
8687
__macro(cudnnCreate); \
8788
__macro(cudnnDestroy); \
8889
__macro(cudnnSetStream); \

0 commit comments

Comments
 (0)