@@ -19,156 +19,110 @@ limitations under the License. */
19
19
namespace paddle {
20
20
namespace operators {
21
21
22
- template <typename DeviceContext, typename T, typename AttrType = T>
22
+ inline void GetDims (const framework::DDim& dim, int axis, int * pre, int * n,
23
+ int * post) {
24
+ *pre = 1 ;
25
+ *post = 1 ;
26
+ *n = dim[axis];
27
+ for (int i = 0 ; i < axis; ++i) {
28
+ (*pre) *= dim[i];
29
+ }
30
+ for (int i = axis + 1 ; i < dim.size (); ++i) {
31
+ (*post) *= dim[i];
32
+ }
33
+ }
34
+
35
+ template <typename DeviceContext, typename T>
23
36
class NormKernel : public framework ::OpKernel<T> {
24
37
public:
25
- void Compute (const framework::ExecutionContext& context) const override {
26
- const framework::Tensor* in_x = context.Input <framework::Tensor>(" X" );
27
- const framework::Tensor* scale = context.Input <framework::Tensor>(" Scale" );
28
- auto * out = context.Output <framework::Tensor>(" Out" );
29
- auto epsilon = static_cast <T>(context.Attr <AttrType>(" epsilon" ));
30
- out->mutable_data <T>(context.GetPlace ());
31
- int batch_size = in_x->dims ()[0 ];
32
- int channels = in_x->dims ()[1 ];
33
- int height = in_x->dims ()[2 ];
34
- int width = in_x->dims ()[3 ];
35
- int fea_len = height * width;
36
- auto * place =
37
- context.template device_context <DeviceContext>().eigen_device ();
38
- auto x =
39
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
40
- *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
41
- // get square
42
- framework::Tensor x_square;
43
- x_square.mutable_data <T>(in_x->dims (), context.GetPlace ());
44
- auto x_square_eigen =
45
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
46
- x_square, framework::make_ddim ({batch_size, fea_len * channels}));
47
- x_square_eigen.device (*place) = x.square ();
48
- auto scale_eigen =
49
- framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten (
50
- *scale);
51
- for (int n = 0 ; n < batch_size; ++n) {
52
- framework::Tensor in_x_batch = in_x->Slice (n, n + 1 );
53
- auto in_x_batch_eigen =
54
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
55
- in_x_batch, framework::make_ddim ({channels, fea_len}));
56
- framework::Tensor x_square_batch = x_square.Slice (n, n + 1 );
57
- auto x_square_batch_eigen =
58
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
59
- x_square_batch, framework::make_ddim ({channels, fea_len}));
60
- framework::Tensor out_batch = out->Slice (n, n + 1 );
61
- auto out_batch_eigen =
62
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
63
- out_batch, framework::make_ddim ({channels, fea_len}));
64
- framework::Tensor tmp_tensor;
65
- tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
66
- context.GetPlace ());
67
- auto tmp = framework::EigenVector<T, Eigen::RowMajor,
68
- Eigen::DenseIndex>::Flatten (tmp_tensor);
69
- // get colsum and sqrt , inverse
70
- auto dim = Eigen::array<int , 1 >({{0 }});
71
- tmp.device (*place) = x_square_batch_eigen.sum (dim);
72
- tmp.device (*place) = (tmp + epsilon).sqrt ().inverse ();
73
- Eigen::array<int , 2 > broadcast_dim_col;
74
- broadcast_dim_col[1 ] = 1 ;
75
- broadcast_dim_col[0 ] = channels;
76
- out_batch_eigen.device (*place) =
77
- in_x_batch_eigen * (tmp.broadcast (broadcast_dim_col));
78
- Eigen::array<int , 2 > broadcast_dim_row;
79
- broadcast_dim_row[1 ] = fea_len;
80
- broadcast_dim_row[0 ] = 1 ;
81
- out_batch_eigen.device (*place) =
82
- out_batch_eigen * (scale_eigen.broadcast (broadcast_dim_row));
83
- }
38
+ void Compute (const framework::ExecutionContext& ctx) const override {
39
+ auto * in_x = ctx.Input <framework::Tensor>(" X" );
40
+ auto * out_y = ctx.Output <framework::Tensor>(" Out" );
41
+ auto * out_norm = ctx.Output <framework::Tensor>(" Norm" );
42
+ out_y->mutable_data <T>(ctx.GetPlace ());
43
+ out_norm->mutable_data <T>(ctx.GetPlace ());
44
+
45
+ auto xdim = in_x->dims ();
46
+ auto ndim = out_norm->dims ();
47
+ T eps = static_cast <T>(ctx.Attr <float >(" epsilon" ));
48
+ int axis = ctx.Attr <int >(" axis" );
49
+ if (axis < 0 ) axis = xdim.size () + axis;
50
+ int pre, n, post;
51
+ GetDims (xdim, axis, &pre, &n, &post);
52
+
53
+ auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
54
+
55
+ Eigen::DSizes<int , 3 > shape (pre, n, post);
56
+ Eigen::DSizes<int , 2 > norm_shape (pre, post);
57
+
58
+ auto x_e = framework::EigenVector<T>::Flatten (*in_x);
59
+ auto y_e = framework::EigenVector<T>::Flatten (*out_y);
60
+ auto norm_e = framework::EigenVector<T>::Flatten (*out_norm);
61
+ auto x = x_e.reshape (shape);
62
+ auto y = y_e.reshape (shape);
63
+ auto norm = norm_e.reshape (norm_shape);
64
+
65
+ Eigen::DSizes<int , 1 > rdim (1 );
66
+ // y = x / sqrt((sum(x * x) + epsilon))
67
+ // norm = sqrt(sum(x * x) + epsilon)
68
+ auto sum = x.pow (2 ).sum (rdim) + eps;
69
+ norm.device (*place) = sum.sqrt ();
70
+ // y = x / norm
71
+ Eigen::DSizes<int , 3 > rshape (pre, 1 , post);
72
+ Eigen::DSizes<int , 3 > bcast (1 , n, 1 );
73
+ y.device (*place) = x / norm.reshape (rshape).broadcast (bcast);
84
74
}
85
75
};
86
76
template <typename DeviceContext, typename T, typename AttrType = T>
87
77
class NormGradKernel : public framework ::OpKernel<T> {
88
78
public:
89
- void Compute (const framework::ExecutionContext& context) const override {
90
- const framework::Tensor* in_x = context.Input <framework::Tensor>(" X" );
91
- const framework::Tensor* scale = context.Input <framework::Tensor>(" Scale" );
92
- const framework::Tensor* out_grad =
93
- context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
94
- auto epsilon = static_cast <T>(context.Attr <AttrType>(" epsilon" ));
95
- framework::Tensor* in_x_grad =
96
- context.Output <framework::Tensor>(framework::GradVarName (" X" ));
97
- in_x_grad->mutable_data <T>(context.GetPlace ());
98
- int batch_size = in_x->dims ()[0 ];
99
- int channels = in_x->dims ()[1 ];
100
- int height = in_x->dims ()[2 ];
101
- int width = in_x->dims ()[3 ];
102
- int fea_len = height * width;
103
- auto * place =
104
- context.template device_context <DeviceContext>().eigen_device ();
105
-
106
- auto scale_eigen =
107
- framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten (
108
- *scale);
109
- auto x =
110
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
111
- *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
112
- // get square
113
- framework::Tensor x_square;
114
- x_square.mutable_data <T>(in_x->dims (), context.GetPlace ());
115
- auto x_square_eigen =
116
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
117
- x_square, framework::make_ddim ({batch_size, fea_len * channels}));
118
- x_square_eigen.device (*place) = x.square ();
119
-
120
- for (int n = 0 ; n < batch_size; ++n) {
121
- framework::Tensor in_x_batch = in_x->Slice (n, n + 1 );
122
- auto in_x_batch_eigen =
123
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
124
- in_x_batch, framework::make_ddim ({channels, fea_len}));
125
- framework::Tensor in_g_batch = in_x_grad->Slice (n, n + 1 );
126
- auto in_g_batch_eigen =
127
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
128
- in_g_batch, framework::make_ddim ({channels, fea_len}));
129
- framework::Tensor x_square_batch = x_square.Slice (n, n + 1 );
130
- auto x_square_batch_eigen =
131
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
132
- x_square_batch, framework::make_ddim ({channels, fea_len}));
133
- framework::Tensor outg_batch = out_grad->Slice (n, n + 1 );
134
- auto outg_batch_eigen =
135
- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
136
- outg_batch, framework::make_ddim ({channels, fea_len}));
137
-
138
- framework::Tensor tmp_tensor;
139
- tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
140
- context.GetPlace ());
141
- auto tmp_eigen =
142
- framework::EigenVector<T, Eigen::RowMajor,
143
- Eigen::DenseIndex>::Flatten (tmp_tensor);
144
- auto dim = Eigen::array<int , 1 >({{0 }});
145
- tmp_eigen.device (*place) = (in_x_batch_eigen * outg_batch_eigen).sum (dim);
146
- framework::Tensor norm_tmp_tensor;
147
- norm_tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
148
- context.GetPlace ());
149
- auto norm_tmp_eigen =
150
- framework::EigenVector<T, Eigen::RowMajor,
151
- Eigen::DenseIndex>::Flatten (norm_tmp_tensor);
152
- norm_tmp_eigen.device (*place) =
153
- (x_square_batch_eigen.sum (dim) + epsilon).sqrt ();
154
- Eigen::array<int , 2 > broadcast_dim_col;
155
- broadcast_dim_col[1 ] = 1 ;
156
- broadcast_dim_col[0 ] = channels;
157
- in_g_batch_eigen.device (*place) =
158
- in_x_batch_eigen * tmp_eigen.broadcast (broadcast_dim_col);
159
- in_g_batch_eigen.device (*place) =
160
- in_g_batch_eigen /
161
- (norm_tmp_eigen * norm_tmp_eigen).broadcast (broadcast_dim_col);
162
- in_g_batch_eigen.device (*place) = outg_batch_eigen - in_g_batch_eigen;
163
- // outg_batch_eigen + (in_g_batch_eigen * -1);
164
- in_g_batch_eigen.device (*place) =
165
- in_g_batch_eigen / norm_tmp_eigen.broadcast (broadcast_dim_col);
166
- Eigen::array<int , 2 > broadcast_dim_row;
167
- broadcast_dim_row[1 ] = fea_len;
168
- broadcast_dim_row[0 ] = 1 ;
169
- in_g_batch_eigen.device (*place) =
170
- in_g_batch_eigen * (scale_eigen.broadcast (broadcast_dim_row));
171
- }
79
+ void Compute (const framework::ExecutionContext& ctx) const override {
80
+ auto * in_x = ctx.Input <framework::Tensor>(" X" );
81
+ auto * in_norm = ctx.Input <framework::Tensor>(" Norm" );
82
+ auto * in_dy = ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
83
+ auto * out_dx = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
84
+ out_dx->mutable_data <T>(ctx.GetPlace ());
85
+
86
+ auto xdim = in_x->dims ();
87
+ int axis = ctx.Attr <int >(" axis" );
88
+ if (axis < 0 ) axis = xdim.size () + axis;
89
+ int pre, n, post;
90
+ GetDims (xdim, axis, &pre, &n, &post);
91
+
92
+ auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
93
+
94
+ auto x_e = framework::EigenVector<T>::Flatten (*in_x);
95
+ auto dy_e = framework::EigenVector<T>::Flatten (*in_dy);
96
+ auto norm_e = framework::EigenVector<T>::Flatten (*in_norm);
97
+ auto dx_e = framework::EigenVector<T>::Flatten (*out_dx);
98
+
99
+ Eigen::DSizes<int , 3 > shape (pre, n, post);
100
+ Eigen::DSizes<int , 2 > norm_shape (pre, post);
101
+ auto x = x_e.reshape (shape);
102
+ auto dy = dy_e.reshape (shape);
103
+ auto norm = norm_e.reshape (norm_shape);
104
+ auto dx = dx_e.reshape (shape);
105
+
106
+ framework::Tensor rsum;
107
+ rsum.mutable_data <T>({pre, post}, ctx.GetPlace ());
108
+ auto sum = framework::EigenTensor<T, 2 >::From (rsum);
109
+
110
+ Eigen::DSizes<int , 1 > rdim (1 );
111
+ Eigen::DSizes<int , 3 > bcast (1 , n, 1 );
112
+ Eigen::DSizes<int , 3 > rshape (pre, 1 , post);
113
+
114
+ // dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
115
+ // = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
116
+ // = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
117
+ // 1. sum = sum(x*dy)
118
+ sum.device (*place) = (x * dy).sum (rdim);
119
+ // 2. dx = x * sum
120
+ dx.device (*place) = sum.reshape (rshape).broadcast (bcast) * x;
121
+ // 3. dx / (sum(x*x) + e)
122
+ // where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
123
+ dx.device (*place) = dx / norm.pow (2 ).broadcast (bcast);
124
+ // 4. [dy - dx] / sqrt(sum(x*x))
125
+ dx.device (*place) = (dy - dx) / norm.broadcast (bcast);
172
126
}
173
127
};
174
128
} // namespace operators