Skip to content

Commit 446d54f

Browse files
committed
update
1 parent ffa22a5 commit 446d54f

File tree

3 files changed

+27
-22
lines changed

3 files changed

+27
-22
lines changed

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
8383

8484
protected:
8585
framework::OpKernelType GetExpectedKernelType(
86-
const ExecutionContext &ctx) const override {
86+
const framework::ExecutionContext &ctx) const override {
8787
auto input_data_type =
8888
framework::ToDataType(ctx.Input<Tensor>("X")->type());
8989
// For float or float16 input tensor, the type of the scale, bias, mean,

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ using DataLayout = framework::DataLayout;
2828
template <typename T>
2929
using CudnnDataType = platform::CudnnDataType<T>;
3030
template <typename T>
31-
using ScalingParamType = typename CudnnDataType<T>::ScalingParamType;
31+
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
3232

3333
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
3434
int *N, int *C, int *H, int *W, int *D) {
@@ -122,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
122122

123123
// alloc memory
124124
y->mutable_data<T>(ctx.GetPlace());
125-
mean_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
126-
variance_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
127-
saved_mean->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
128-
saved_variance->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
125+
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
126+
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
127+
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
128+
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
129129

130130
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
131-
math::SetConstant<platform::CUDADeviceContext, ScalingParamType<T>> functor;
132-
functor(dev_ctx, saved_mean, static_cast<ScalingParamType<T>>(0));
133-
functor(dev_ctx, saved_variance, static_cast<ScalingParamType<T>>(0));
131+
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
132+
functor;
133+
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
134+
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
134135

135136
auto handle = dev_ctx.cudnn_handle();
136137

@@ -151,10 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
151152
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
152153
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
153154
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
154-
bn_param_desc_, scale->template data<ScalingParamType<T>>(),
155-
bias->template data<ScalingParamType<T>>(),
156-
est_mean->template data<ScalingParamType<T>>(),
157-
est_var->template data<ScalingParamType<T>>(), epsilon));
155+
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
156+
bias->template data<BatchNormParamType<T>>(),
157+
est_mean->template data<BatchNormParamType<T>>(),
158+
est_var->template data<BatchNormParamType<T>>(), epsilon));
158159
} else {
159160
// Run training mode.
160161
// obtain running mean and running inv var, and see if we need to
@@ -165,14 +166,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
165166
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
166167
data_desc_, x->template data<T>(), data_desc_,
167168
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
168-
scale->template data<ScalingParamType<T>>(),
169-
bias->template data<ScalingParamType<T>>(), this_factor,
170-
mean_out->template mutable_data<ScalingParamType<T>>(ctx.GetPlace()),
171-
variance_out->template mutable_data<ScalingParamType<T>>(
169+
scale->template data<BatchNormParamType<T>>(),
170+
bias->template data<BatchNormParamType<T>>(), this_factor,
171+
mean_out->template mutable_data<BatchNormParamType<T>>(
172172
ctx.GetPlace()),
173-
epsilon, saved_mean->template mutable_data<ScalingParamType<T>>(
173+
variance_out->template mutable_data<BatchNormParamType<T>>(
174+
ctx.GetPlace()),
175+
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
174176
ctx.GetPlace()),
175-
saved_variance->template mutable_data<ScalingParamType<T>>(
177+
saved_variance->template mutable_data<BatchNormParamType<T>>(
176178
ctx.GetPlace())));
177179
}
178180

paddle/fluid/platform/cudnn_helper.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
8686
public:
8787
static const cudnnDataType_t type = CUDNN_DATA_HALF;
8888
// The scaling param type is float for HALF and FLOAT tensors
89-
typedef const float ScalingParamType;
89+
using ScalingParamType = const float;
90+
using BatchNormParamType = float;
9091
static ScalingParamType* kOne() {
9192
static ScalingParamType v = 1.0;
9293
return &v;
@@ -101,7 +102,8 @@ template <>
101102
class CudnnDataType<float> {
102103
public:
103104
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
104-
typedef const float ScalingParamType;
105+
using ScalingParamType = const float;
106+
using BatchNormParamType = float;
105107
static ScalingParamType* kOne() {
106108
static ScalingParamType v = 1.0;
107109
return &v;
@@ -116,7 +118,8 @@ template <>
116118
class CudnnDataType<double> {
117119
public:
118120
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
119-
typedef const double ScalingParamType;
121+
using ScalingParamType = const double;
122+
using BatchNormParamType = double;
120123
static ScalingParamType* kOne() {
121124
static ScalingParamType v = 1.0;
122125
return &v;

0 commit comments

Comments
 (0)