Commit 19fd071

Make the normalization operator more general and fix bug in l2_normalize. (#11348)
* Add normalization operator.
  1. Refine the raw norm_op and make it general enough to normalize a Tensor along any axis.
  2. Fix a bug in the l2_normalize API, which lacked a sqrt after `reduce_sum` (see the NumPy sketch below).
  3. Use norm_op to reimplement the l2_normalize API.
  4. Fix a bug in test_normalization_wrapper.py.
1 parent f15504e commit 19fd071
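
For reference, the l2_normalize fix in item 2 boils down to taking the square root of the reduced sum before dividing. A minimal NumPy sketch of the intended behaviour (illustrative only; the function name and the epsilon default are assumptions, not the Paddle API):

    import numpy as np

    def l2_normalize(x, axis, epsilon=1e-12):
        # Correct behaviour: divide by sqrt(sum(x^2) + epsilon) along `axis`.
        # The pre-fix code effectively divided by (sum(x^2) + epsilon), i.e.
        # the sqrt after reduce_sum was missing.
        norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True) + epsilon)
        return x / norm

    x = np.random.rand(2, 3, 4).astype("float32")
    y = l2_normalize(x, axis=1)
    # Along axis 1 every vector now has (approximately) unit L2 norm.
    assert np.allclose(np.sum(y * y, axis=1), 1.0, atol=1e-4)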

File tree
7 files changed: +193 −271 lines


paddle/fluid/operators/norm_op.cc

Lines changed: 37 additions & 43 deletions
@@ -16,40 +16,34 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename AttrType>
 class NormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput(
-        "X",
-        "(Tensor) The input tensor of norm operator. "
-        "The format of input tensor is NCHW. Where N is batch size, C is the "
-        "number of channels, H and W is the height and width of feature.");
-    AddInput("Scale",
-             "(Tensor) The input tensor of norm operator. "
-             "The format of input tensor is C * 1.");
-    AddAttr<AttrType>("epsilon",
-                      "(float, default 1e-10) Constant "
-                      "for numerical stability.")
+    AddInput("X", "(Tensor) A tensor of rank >= axis.");
+    AddAttr<int>("axis",
+                 "The axis on which to apply normalization. If axis < 0, "
+                 "the dimension to normalize is rank(X) + axis. -1 is "
+                 "the last dimension.");
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-10) The epsilon value is used "
+                   "to avoid division by zero.")
         .SetDefault(1.0e-10f);
-    AddOutput("Out",
-              "(Tensor) The output tensor of norm operator."
-              "N * M."
-              "M = C * H * W");
+    AddOutput("Norm",
+              "(Tensor) A tensor that saves `sqrt(sum(x^2) + epsilon)`, "
+              "which will be reused in the backward kernel.")
+        .AsIntermediate();
+    AddOutput("Out", "(Tensor) A tensor of the same shape as X.");
     AddComment(R"DOC(
-  "Input shape: $(N, C, H, W)$
-        Scale shape: $(C, 1)$
-        Output shape: $(N, C, H, W)$
-        Where
-        forward
-          $$
-  [\frac {x_{1}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{2}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{3}}{\sqrt{\sum{x_{i}^{2}}}} \cdots \frac {x_{n}}{\sqrt{\sum{x_{i}^{2}}}}]
-          $$
-        backward
-          $$
-  \frac{\frac{\mathrm{d}L}{\mathrm{d}y_{1}} - \frac{x_{1}\sum{\frac{\mathrm{d}L}{\mathrm{d}y_{j}} x_{j}}}{\sum{x_{j}^{2}}}}{\sqrt{\sum{x_{j}^{2}}}}
-          $$
-        )DOC");
+
+Given a tensor, apply 2-normalization along the provided axis.
+
+$$
+y = \frac{x}{\sqrt{\sum{x^2} + epsilon}}
+$$
+
+where $\sum{x^2}$ is calculated along the `axis` dimension.
+
+)DOC");
   }
 };
 
@@ -58,15 +52,15 @@ class NormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of NormOp"
-                   "should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Scale"),
-                   "Input(Scale) of NormOp"
-                   "should not be null.");
+                   "Input(X) of NormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of NormOp should not be null.");
-    auto in_x_dims = ctx->GetInputDim("X");
-    ctx->SetOutputDim("Out", in_x_dims);
+    auto xdim = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", xdim);
+    int axis = ctx->Attrs().Get<int>("axis");
+    if (axis < 0) axis = xdim.size() + axis;
+    xdim[axis] = 1;
+    ctx->SetOutputDim("Norm", xdim);
   }
 };
 
@@ -84,12 +78,12 @@ class NormOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker<float>,
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::NormKernel<paddle::platform::CPUDeviceContext, double, float>);
-REGISTER_OP_CPU_KERNEL(
-    norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::NormGradKernel<paddle::platform::CPUDeviceContext, double, float>);
+REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
+                       ops::NormKernel<CPU, double>);
+REGISTER_OP_CPU_KERNEL(norm_grad, ops::NormGradKernel<CPU, float>,
+                       ops::NormGradKernel<CPU, double>);
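
As a quick illustration of the new InferShape logic above (Out keeps X's shape, while Norm keeps X's shape with the normalized axis reduced to 1), here is a hedged Python sketch; `infer_norm_shapes` is a hypothetical helper, not Paddle code:

    def infer_norm_shapes(x_shape, axis):
        # Mirrors NormOp::InferShape: Out has X's shape; Norm has X's shape
        # with the normalized axis set to 1.
        if axis < 0:
            axis += len(x_shape)
        out_shape = list(x_shape)
        norm_shape = list(x_shape)
        norm_shape[axis] = 1
        return out_shape, norm_shape

    # Example: an NCHW tensor normalized along the channel axis.
    print(infer_norm_shapes((8, 16, 32, 32), axis=1))   # ([8, 16, 32, 32], [8, 1, 32, 32])
    print(infer_norm_shapes((8, 16, 32, 32), axis=-1))  # ([8, 16, 32, 32], [8, 16, 32, 1])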

paddle/fluid/operators/norm_op.cu

Lines changed: 6 additions & 6 deletions
@@ -16,9 +16,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/norm_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    norm, ops::NormKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::NormKernel<paddle::platform::CUDADeviceContext, double, float>);
-REGISTER_OP_CUDA_KERNEL(
-    norm_grad, ops::NormGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::NormGradKernel<paddle::platform::CUDADeviceContext, double, float>);
+using CUDA = paddle::platform::CUDADeviceContext;
+
+REGISTER_OP_CUDA_KERNEL(norm, ops::NormKernel<CUDA, float>,
+                        ops::NormKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradKernel<CUDA, float>,
+                        ops::NormGradKernel<CUDA, double>);

paddle/fluid/operators/norm_op.h

Lines changed: 97 additions & 143 deletions
@@ -19,156 +19,110 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T, typename AttrType = T>
+inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
+                    int* post) {
+  *pre = 1;
+  *post = 1;
+  *n = dim[axis];
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= dim[i];
+  }
+  for (int i = axis + 1; i < dim.size(); ++i) {
+    (*post) *= dim[i];
+  }
+}
+
+template <typename DeviceContext, typename T>
 class NormKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
-    const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
-    auto* out = context.Output<framework::Tensor>("Out");
-    auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
-    out->mutable_data<T>(context.GetPlace());
-    int batch_size = in_x->dims()[0];
-    int channels = in_x->dims()[1];
-    int height = in_x->dims()[2];
-    int width = in_x->dims()[3];
-    int fea_len = height * width;
-    auto* place =
-        context.template device_context<DeviceContext>().eigen_device();
-    auto x =
-        framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *in_x, framework::make_ddim({batch_size, fea_len * channels}));
-    // get square
-    framework::Tensor x_square;
-    x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
-    auto x_square_eigen =
-        framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            x_square, framework::make_ddim({batch_size, fea_len * channels}));
-    x_square_eigen.device(*place) = x.square();
-    auto scale_eigen =
-        framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten(
-            *scale);
-    for (int n = 0; n < batch_size; ++n) {
-      framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
-      auto in_x_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              in_x_batch, framework::make_ddim({channels, fea_len}));
-      framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
-      auto x_square_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              x_square_batch, framework::make_ddim({channels, fea_len}));
-      framework::Tensor out_batch = out->Slice(n, n + 1);
-      auto out_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              out_batch, framework::make_ddim({channels, fea_len}));
-      framework::Tensor tmp_tensor;
-      tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
-                                 context.GetPlace());
-      auto tmp = framework::EigenVector<T, Eigen::RowMajor,
-                                        Eigen::DenseIndex>::Flatten(tmp_tensor);
-      // get colsum and sqrt , inverse
-      auto dim = Eigen::array<int, 1>({{0}});
-      tmp.device(*place) = x_square_batch_eigen.sum(dim);
-      tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
-      Eigen::array<int, 2> broadcast_dim_col;
-      broadcast_dim_col[1] = 1;
-      broadcast_dim_col[0] = channels;
-      out_batch_eigen.device(*place) =
-          in_x_batch_eigen * (tmp.broadcast(broadcast_dim_col));
-      Eigen::array<int, 2> broadcast_dim_row;
-      broadcast_dim_row[1] = fea_len;
-      broadcast_dim_row[0] = 1;
-      out_batch_eigen.device(*place) =
-          out_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
-    }
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_x = ctx.Input<framework::Tensor>("X");
+    auto* out_y = ctx.Output<framework::Tensor>("Out");
+    auto* out_norm = ctx.Output<framework::Tensor>("Norm");
+    out_y->mutable_data<T>(ctx.GetPlace());
+    out_norm->mutable_data<T>(ctx.GetPlace());
+
+    auto xdim = in_x->dims();
+    auto ndim = out_norm->dims();
+    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis = xdim.size() + axis;
+    int pre, n, post;
+    GetDims(xdim, axis, &pre, &n, &post);
+
+    auto* place = ctx.template device_context<DeviceContext>().eigen_device();
+
+    Eigen::DSizes<int, 3> shape(pre, n, post);
+    Eigen::DSizes<int, 2> norm_shape(pre, post);
+
+    auto x_e = framework::EigenVector<T>::Flatten(*in_x);
+    auto y_e = framework::EigenVector<T>::Flatten(*out_y);
+    auto norm_e = framework::EigenVector<T>::Flatten(*out_norm);
+    auto x = x_e.reshape(shape);
+    auto y = y_e.reshape(shape);
+    auto norm = norm_e.reshape(norm_shape);
+
+    Eigen::DSizes<int, 1> rdim(1);
+    // norm = sqrt(sum(x * x) + epsilon)
+    // y = x / norm
+    auto sum = x.pow(2).sum(rdim) + eps;
+    norm.device(*place) = sum.sqrt();
+    Eigen::DSizes<int, 3> rshape(pre, 1, post);
+    Eigen::DSizes<int, 3> bcast(1, n, 1);
+    y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
   }
 };
 template <typename DeviceContext, typename T, typename AttrType = T>
 class NormGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
-    const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
-    const framework::Tensor* out_grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
-    framework::Tensor* in_x_grad =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
-    in_x_grad->mutable_data<T>(context.GetPlace());
-    int batch_size = in_x->dims()[0];
-    int channels = in_x->dims()[1];
-    int height = in_x->dims()[2];
-    int width = in_x->dims()[3];
-    int fea_len = height * width;
-    auto* place =
-        context.template device_context<DeviceContext>().eigen_device();
-
-    auto scale_eigen =
-        framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten(
-            *scale);
-    auto x =
-        framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *in_x, framework::make_ddim({batch_size, fea_len * channels}));
-    // get square
-    framework::Tensor x_square;
-    x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
-    auto x_square_eigen =
-        framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            x_square, framework::make_ddim({batch_size, fea_len * channels}));
-    x_square_eigen.device(*place) = x.square();
-
-    for (int n = 0; n < batch_size; ++n) {
-      framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
-      auto in_x_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              in_x_batch, framework::make_ddim({channels, fea_len}));
-      framework::Tensor in_g_batch = in_x_grad->Slice(n, n + 1);
-      auto in_g_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              in_g_batch, framework::make_ddim({channels, fea_len}));
-      framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
-      auto x_square_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              x_square_batch, framework::make_ddim({channels, fea_len}));
-      framework::Tensor outg_batch = out_grad->Slice(n, n + 1);
-      auto outg_batch_eigen =
-          framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
-              outg_batch, framework::make_ddim({channels, fea_len}));
-
-      framework::Tensor tmp_tensor;
-      tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
-                                 context.GetPlace());
-      auto tmp_eigen =
-          framework::EigenVector<T, Eigen::RowMajor,
-                                 Eigen::DenseIndex>::Flatten(tmp_tensor);
-      auto dim = Eigen::array<int, 1>({{0}});
-      tmp_eigen.device(*place) = (in_x_batch_eigen * outg_batch_eigen).sum(dim);
-      framework::Tensor norm_tmp_tensor;
-      norm_tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
-                                      context.GetPlace());
-      auto norm_tmp_eigen =
-          framework::EigenVector<T, Eigen::RowMajor,
-                                 Eigen::DenseIndex>::Flatten(norm_tmp_tensor);
-      norm_tmp_eigen.device(*place) =
-          (x_square_batch_eigen.sum(dim) + epsilon).sqrt();
-      Eigen::array<int, 2> broadcast_dim_col;
-      broadcast_dim_col[1] = 1;
-      broadcast_dim_col[0] = channels;
-      in_g_batch_eigen.device(*place) =
-          in_x_batch_eigen * tmp_eigen.broadcast(broadcast_dim_col);
-      in_g_batch_eigen.device(*place) =
-          in_g_batch_eigen /
-          (norm_tmp_eigen * norm_tmp_eigen).broadcast(broadcast_dim_col);
-      in_g_batch_eigen.device(*place) = outg_batch_eigen - in_g_batch_eigen;
-      // outg_batch_eigen + (in_g_batch_eigen * -1);
-      in_g_batch_eigen.device(*place) =
-          in_g_batch_eigen / norm_tmp_eigen.broadcast(broadcast_dim_col);
-      Eigen::array<int, 2> broadcast_dim_row;
-      broadcast_dim_row[1] = fea_len;
-      broadcast_dim_row[0] = 1;
-      in_g_batch_eigen.device(*place) =
-          in_g_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
-    }
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_x = ctx.Input<framework::Tensor>("X");
+    auto* in_norm = ctx.Input<framework::Tensor>("Norm");
+    auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    out_dx->mutable_data<T>(ctx.GetPlace());
+
+    auto xdim = in_x->dims();
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis = xdim.size() + axis;
+    int pre, n, post;
+    GetDims(xdim, axis, &pre, &n, &post);
+
+    auto* place = ctx.template device_context<DeviceContext>().eigen_device();
+
+    auto x_e = framework::EigenVector<T>::Flatten(*in_x);
+    auto dy_e = framework::EigenVector<T>::Flatten(*in_dy);
+    auto norm_e = framework::EigenVector<T>::Flatten(*in_norm);
+    auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
+
+    Eigen::DSizes<int, 3> shape(pre, n, post);
+    Eigen::DSizes<int, 2> norm_shape(pre, post);
+    auto x = x_e.reshape(shape);
+    auto dy = dy_e.reshape(shape);
+    auto norm = norm_e.reshape(norm_shape);
+    auto dx = dx_e.reshape(shape);
+
+    framework::Tensor rsum;
+    rsum.mutable_data<T>({pre, post}, ctx.GetPlace());
+    auto sum = framework::EigenTensor<T, 2>::From(rsum);
+
+    Eigen::DSizes<int, 1> rdim(1);
+    Eigen::DSizes<int, 3> bcast(1, n, 1);
+    Eigen::DSizes<int, 3> rshape(pre, 1, post);
+
+    // dx = dy / norm - x * sum(x*dy) / norm^3
+    //    = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x) + e)
+    // 1. sum = sum(x*dy)
+    sum.device(*place) = (x * dy).sum(rdim);
+    // 2. dx = x * sum
+    dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x;
+    // 3. dx = dx / (sum(x*x) + e)
+    //    where norm.pow(2) = sum(x*x) + e, computed in the forward pass.
+    dx.device(*place) = dx / norm.pow(2).broadcast(bcast);
+    // 4. dx = (dy - dx) / sqrt(sum(x*x) + e)
+    dx.device(*place) = (dy - dx) / norm.broadcast(bcast);
   }
 };
 }  // namespace operators
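
Both kernels above flatten the input to a (pre, n, post) view and reduce over the middle dimension. The sketch below re-derives the same forward and backward math in NumPy and checks the gradient with a finite difference; it is an illustration under the stated assumptions, not Paddle code (the helper names are made up):

    import numpy as np

    def get_dims(shape, axis):
        # Same decomposition as GetDims: flatten to (pre, n, post).
        pre = int(np.prod(shape[:axis], dtype=np.int64))
        n = shape[axis]
        post = int(np.prod(shape[axis + 1:], dtype=np.int64))
        return pre, n, post

    def norm_forward(x, axis, epsilon=1e-10):
        axis = axis if axis >= 0 else axis + x.ndim
        pre, n, post = get_dims(x.shape, axis)
        xr = x.reshape(pre, n, post)
        # norm = sqrt(sum(x*x) + eps); y = x / norm, broadcast over the n axis.
        norm = np.sqrt(np.sum(xr * xr, axis=1, keepdims=True) + epsilon)
        return (xr / norm).reshape(x.shape), norm.reshape(pre, post)

    def norm_backward(x, norm, dy, axis):
        axis = axis if axis >= 0 else axis + x.ndim
        pre, n, post = get_dims(x.shape, axis)
        xr, dyr = x.reshape(pre, n, post), dy.reshape(pre, n, post)
        nr = norm.reshape(pre, 1, post)
        # dx = [dy - x * sum(x*dy) / norm^2] / norm, matching the kernel's steps.
        s = np.sum(xr * dyr, axis=1, keepdims=True)
        dx = (dyr - xr * s / (nr * nr)) / nr
        return dx.reshape(x.shape)

    # Finite-difference check of the backward formula for the loss L = sum(y).
    x = np.random.rand(2, 3, 4).astype("float64")
    y, norm = norm_forward(x, axis=1)
    dx = norm_backward(x, norm, np.ones_like(y), axis=1)
    eps = 1e-6
    x2 = x.copy()
    x2[0, 1, 2] += eps
    num = (norm_forward(x2, axis=1)[0].sum() - y.sum()) / eps
    assert abs(num - dx[0, 1, 2]) < 1e-4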
