|
12 | 12 | See the License for the specific language governing permissions and
|
13 | 13 | limitations under the License. */
|
14 | 14 |
|
15 |
| -#define EIGEN_USE_GPU |
16 | 15 | #include "paddle/operators/lrn_op.h"
|
17 | 16 |
|
18 |
| -namespace ops = paddle::operators; |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +template <typename T> |
| 21 | +__global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C, |
| 22 | + int H, int W, int size, T k, T alpha) { |
| 23 | + const int idx = threadIdx.x + blockIdx.x * blockDim.x; |
| 24 | + if (idx < img_size) { |
| 25 | + const int w = idx % W; |
| 26 | + const int h = (idx / W) % H; |
| 27 | + const int n = idx / W / H; |
| 28 | + const int offset = (n * C * H + h) * W + w; |
| 29 | + |
| 30 | + in += offset; |
| 31 | + mid += offset; |
| 32 | + const int step = H * W; |
| 33 | + const int pre_pad = (size - 1) / 2; |
| 34 | + const int post_pad = size - pre_pad - 1; |
| 35 | + |
| 36 | + T accum = 0; |
| 37 | + int index = 0; |
| 38 | + while (index < C + post_pad) { |
| 39 | + if (index < C) { |
| 40 | + T val = in[index * step]; |
| 41 | + accum += val * val; |
| 42 | + } |
| 43 | + if (index >= size) { |
| 44 | + T val = in[(index - size) * step]; |
| 45 | + accum -= val * val; |
| 46 | + } |
| 47 | + if (index >= post_pad) { |
| 48 | + mid[(index - post_pad) * step] = k + accum * alpha; |
| 49 | + } |
| 50 | + ++index; |
| 51 | + } |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +template <typename T> |
| 56 | +__global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid, |
| 57 | + T negative_beta, T* out) { |
| 58 | + const int index = threadIdx.x + blockIdx.x * blockDim.x; |
| 59 | + if (index < input_size) { |
| 60 | + out[index] = in[index] * pow(mid[index], negative_beta); |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +template <typename T> |
| 65 | +void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs, |
| 66 | + T* outputs, T* mid, int N, int C, int H, int W, int n, T k, |
| 67 | + T alpha, T beta) { |
| 68 | + int img_size = N * H * W; |
| 69 | + const int block_size = 1024; |
| 70 | + int grid_size = (img_size + block_size - 1) / block_size; |
| 71 | + |
| 72 | + KeCMRNormFillScale< |
| 73 | + T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>( |
| 74 | + img_size, inputs, mid, C, H, W, n, k, alpha); |
| 75 | + |
| 76 | + int input_size = N * H * W * C; |
| 77 | + grid_size = (input_size + block_size - 1) / block_size; |
| 78 | + KeCMRNormOutput< |
| 79 | + T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>( |
| 80 | + input_size, inputs, mid, -beta, outputs); |
| 81 | +} |
| 82 | + |
| 83 | +template <typename T> |
| 84 | +struct LRNFunctor<platform::GPUPlace, T> { |
| 85 | + void operator()(const framework::ExecutionContext& ctx, |
| 86 | + const framework::Tensor& input, framework::Tensor* out, |
| 87 | + framework::Tensor* mid, int N, int C, int H, int W, int n, |
| 88 | + T k, T alpha, T beta) { |
| 89 | + CrossMapNormal<T>( |
| 90 | + ctx, input.data<T>(), out->mutable_data<T>(ctx.GetPlace()), |
| 91 | + mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta); |
| 92 | + } |
| 93 | +}; |
| 94 | + |
| 95 | +template struct LRNFunctor<platform::GPUPlace, float>; |
| 96 | +template struct LRNFunctor<platform::GPUPlace, double>; |
19 | 97 |
|
| 98 | +template <typename T> |
| 99 | +__global__ void KeCMRNormDiff(int img_size, const T* x, const T* out, |
| 100 | + const T* mid, T* x_g, const T* out_g, int C, |
| 101 | + int H, int W, int size, T negative_beta, |
| 102 | + T ratio) { |
| 103 | + const int idx = threadIdx.x + blockIdx.x * blockDim.x; |
| 104 | + if (idx < img_size) { |
| 105 | + const int w = idx % W; |
| 106 | + const int h = (idx / W) % H; |
| 107 | + const int n = idx / W / H; |
| 108 | + const int offset = (n * C * H + h) * W + w; |
| 109 | + x += offset; |
| 110 | + out += offset; |
| 111 | + mid += offset; |
| 112 | + out_g += offset; |
| 113 | + x_g += offset; |
| 114 | + |
| 115 | + const int step = H * W; |
| 116 | + const int pre_pad = size - (size + 1) / 2; |
| 117 | + const int post_pad = size - pre_pad - 1; |
| 118 | + |
| 119 | + int index = 0; |
| 120 | + T accum = 0; |
| 121 | + // TODO(gongwb): optimize this with thread shared array. |
| 122 | + while (index < C + post_pad) { |
| 123 | + if (index < C) { |
| 124 | + x_g[index * step] = 0.0; |
| 125 | + accum += out_g[index * step] * out[index * step] / mid[index * step]; |
| 126 | + } |
| 127 | + if (index >= size) { |
| 128 | + accum -= out_g[(index - size) * step] * out[(index - size) * step] / |
| 129 | + mid[(index - size) * step]; |
| 130 | + } |
| 131 | + if (index >= post_pad) { |
| 132 | + x_g[(index - post_pad) * step] += |
| 133 | + out_g[(index - post_pad) * step] * |
| 134 | + pow(mid[(index - post_pad) * step], negative_beta) - |
| 135 | + ratio * x[(index - post_pad) * step] * accum; |
| 136 | + } |
| 137 | + ++index; |
| 138 | + } |
| 139 | + } |
| 140 | +} |
| 141 | + |
| 142 | +template <typename T> |
| 143 | +void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x, |
| 144 | + const T* out, const T* mid, T* x_g, const T* out_g, |
| 145 | + int N, int C, int H, int W, int n, T alpha, T beta) { |
| 146 | + int img_size = N * H * W; |
| 147 | + |
| 148 | + const int block_size = 1024; |
| 149 | + int grid_size = (img_size + block_size - 1) / block_size; |
| 150 | + |
| 151 | + KeCMRNormDiff< |
| 152 | + T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>( |
| 153 | + img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta, |
| 154 | + 2.0f * alpha * beta); |
| 155 | +} |
| 156 | + |
| 157 | +template <typename T> |
| 158 | +struct LRNGradFunctor<platform::GPUPlace, T> { |
| 159 | + void operator()(const framework::ExecutionContext& ctx, |
| 160 | + const framework::Tensor& x, const framework::Tensor& out, |
| 161 | + const framework::Tensor& mid, framework::Tensor* x_g, |
| 162 | + const framework::Tensor& out_g, int N, int C, int H, int W, |
| 163 | + int n, T alpha, T beta) { |
| 164 | + CrossMapNormalGrad<T>(ctx, x.data<T>(), out.data<T>(), mid.data<T>(), |
| 165 | + x_g->mutable_data<T>(ctx.GetPlace()), out_g.data<T>(), |
| 166 | + N, C, H, W, n, alpha, beta); |
| 167 | + } |
| 168 | +}; |
| 169 | + |
| 170 | +template struct LRNGradFunctor<platform::GPUPlace, float>; |
| 171 | +template struct LRNGradFunctor<platform::GPUPlace, double>; |
| 172 | +} // namespace operators |
| 173 | +} // namespace paddle |
| 174 | + |
| 175 | +namespace ops = paddle::operators; |
20 | 176 | REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
|
21 | 177 | REGISTER_OP_GPU_KERNEL(lrn_grad,
|
22 | 178 | ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
|
0 commit comments