1
- /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1
+ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -11,14 +11,151 @@ distributed under the License is distributed on an "AS IS" BASIS,
11
11
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
- #define EIGEN_USE_GPU
15
14
15
+ #include < algorithm>
16
+ #include " cub/cub.cuh"
16
17
#include " paddle/fluid/operators/norm_op.h"
17
18
19
+ namespace paddle {
20
+ namespace operators {
21
+
22
+ __device__ __forceinline__ float square_root (float x) { return sqrtf (x); }
23
+
24
+ __device__ __forceinline__ double square_root (double x) { return sqrt (x); }
25
+
26
+ template <typename T, int BlockDim>
27
+ __global__ void Normalize (const T* x, const int pre,
28
+ const int axis_n, // dim in axis
29
+ const int post, const T eps, T* y, T* out_norm) {
30
+ typedef cub::BlockReduce<T, BlockDim> BlockReduce;
31
+ __shared__ typename BlockReduce::TempStorage temp_storage;
32
+ int num = pre * post;
33
+ for (int i = blockIdx .x ; i < num; i += gridDim .x ) {
34
+ int base = (i / post) * post * axis_n + (i % post);
35
+
36
+ T sum = 0.0 ;
37
+ __shared__ T norm;
38
+ for (int j = threadIdx .x ; j < axis_n; j += blockDim .x ) {
39
+ const T x_ij = x[base + j * post];
40
+ sum += x_ij * x_ij;
41
+ }
42
+ T reduce_result = BlockReduce (temp_storage).Sum (sum);
43
+
44
+ if (threadIdx .x == 0 ) {
45
+ norm = square_root (reduce_result + eps);
46
+ out_norm[i] = norm;
47
+ }
48
+ __syncthreads ();
49
+ for (int j = threadIdx .x ; j < axis_n; j += blockDim .x ) {
50
+ const int index = base + j * post;
51
+ y[index] = x[index] / norm;
52
+ }
53
+ }
54
+ }
55
+
56
+ template <typename DeviceContext, typename T>
57
+ class NormCUDAKernel : public framework ::OpKernel<T> {
58
+ public:
59
+ void Compute (const framework::ExecutionContext& ctx) const override {
60
+ auto * in_x = ctx.Input <framework::Tensor>(" X" );
61
+ auto * out_y = ctx.Output <framework::Tensor>(" Out" );
62
+ auto * out_norm = ctx.Output <framework::Tensor>(" Norm" );
63
+ const T* x = in_x->data <T>();
64
+ T* y = out_y->mutable_data <T>(ctx.GetPlace ());
65
+ T* norm = out_norm->mutable_data <T>(ctx.GetPlace ());
66
+
67
+ auto xdim = in_x->dims ();
68
+ auto ndim = out_norm->dims ();
69
+ int axis = ctx.Attr <int >(" axis" );
70
+ T eps = static_cast <T>(ctx.Attr <float >(" epsilon" ));
71
+ if (axis < 0 ) axis = xdim.size () + axis;
72
+ int pre, n, post;
73
+ GetDims (xdim, axis, &pre, &n, &post);
74
+
75
+ auto & dev_ctx = ctx.cuda_device_context ();
76
+
77
+ const int block = 512 ;
78
+ int max_threads = dev_ctx.GetMaxPhysicalThreadCount ();
79
+ const int max_blocks = std::max (max_threads / block, 1 );
80
+ int grid = std::min (max_blocks, pre * post);
81
+ Normalize<T, block><<<grid, block, 0 , dev_ctx.stream()>>> (x, pre, n, post,
82
+ eps, y, norm);
83
+ }
84
+ };
85
+
86
+ template <typename T, int BlockDim>
87
+ __global__ void NormalizeGradient (const T* x, const T* x_norm, const T* y_grad,
88
+ const int pre, const int axis_n,
89
+ const int post, T* x_grad) {
90
+ typedef cub::BlockReduce<T, BlockDim> BlockReduce;
91
+ __shared__ typename BlockReduce::TempStorage temp_storage_sum;
92
+ int num = pre * post;
93
+ for (int i = blockIdx .x ; i < num; i += gridDim .x ) {
94
+ T sum = 0.0 ;
95
+ __shared__ T row_sum;
96
+ __shared__ T row_sqrt_norm;
97
+ __shared__ T row_norm;
98
+
99
+ auto base = (i / post) * post * axis_n + (i % post);
100
+
101
+ for (int j = threadIdx .x ; j < axis_n; j += blockDim .x ) {
102
+ int index = base + j * post;
103
+ sum += x[index] * y_grad[index];
104
+ }
105
+ T reduce_result = BlockReduce (temp_storage_sum).Sum (sum);
106
+
107
+ if (threadIdx .x == 0 ) {
108
+ row_sum = reduce_result;
109
+ row_sqrt_norm = x_norm[i];
110
+ row_norm = row_sqrt_norm * row_sqrt_norm;
111
+ }
112
+ __syncthreads ();
113
+ for (int j = threadIdx .x ; j < axis_n; j += blockDim .x ) {
114
+ int index = base + j * post;
115
+ const T x_ij = x[index];
116
+ const T dy_ij = y_grad[index];
117
+ x_grad[index] = (dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm;
118
+ }
119
+ }
120
+ }
121
+
122
+ template <typename DeviceContext, typename T, typename AttrType = T>
123
+ class NormGradCUDAKernel : public framework ::OpKernel<T> {
124
+ public:
125
+ void Compute (const framework::ExecutionContext& ctx) const override {
126
+ auto * in_x = ctx.Input <framework::Tensor>(" X" );
127
+ auto * in_norm = ctx.Input <framework::Tensor>(" Norm" );
128
+ auto * in_dy = ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
129
+ auto * out_dx = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
130
+ T* dx = out_dx->mutable_data <T>(ctx.GetPlace ());
131
+ const T* x = in_x->data <T>();
132
+ const T* x_norm = in_norm->data <T>();
133
+ const T* dy = in_dy->data <T>();
134
+
135
+ auto xdim = in_x->dims ();
136
+ int axis = ctx.Attr <int >(" axis" );
137
+ if (axis < 0 ) axis = xdim.size () + axis;
138
+ int pre, n, post;
139
+ GetDims (xdim, axis, &pre, &n, &post);
140
+
141
+ auto & dev_ctx = ctx.cuda_device_context ();
142
+
143
+ const int block = 512 ;
144
+ int max_threads = dev_ctx.GetMaxPhysicalThreadCount ();
145
+ const int max_blocks = std::max (max_threads / block, 1 );
146
+ int grid = std::min (max_blocks, pre * post);
147
+ NormalizeGradient<T, block><<<grid, block, 0 , dev_ctx.stream()>>> (
148
+ x, x_norm, dy, pre, n, post, dx);
149
+ }
150
+ };
151
+
152
+ } // namespace operators
153
+ } // namespace paddle
154
+
18
155
namespace ops = paddle::operators;
19
156
using CUDA = paddle::platform::CUDADeviceContext;
20
157
21
- REGISTER_OP_CUDA_KERNEL (norm, ops::NormKernel <CUDA, float >,
22
- ops::NormKernel <CUDA, double >);
23
- REGISTER_OP_CUDA_KERNEL (norm_grad, ops::NormGradKernel <CUDA, float >,
24
- ops::NormGradKernel <CUDA, double >);
158
+ REGISTER_OP_CUDA_KERNEL (norm, ops::NormCUDAKernel <CUDA, float >,
159
+ ops::NormCUDAKernel <CUDA, double >);
160
+ REGISTER_OP_CUDA_KERNEL (norm_grad, ops::NormGradCUDAKernel <CUDA, float >,
161
+ ops::NormGradCUDAKernel <CUDA, double >);
0 commit comments