|
| 1 | +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/operators/math/unpooling.h" |
| 16 | +#include "paddle/platform/cuda_helper.h" |
| 17 | + |
| 18 | +namespace paddle { |
| 19 | +namespace operators { |
| 20 | +namespace math { |
| 21 | +template <typename T> |
| 22 | +__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, |
| 23 | + const int* indices_data, |
| 24 | + const int input_height, const int input_width, |
| 25 | + const int channels, T* output_data, |
| 26 | + const int output_height, |
| 27 | + const int output_width) { |
| 28 | + int in_n_stride = input_height * input_width * channels; |
| 29 | + int in_c_stride = input_height * input_width; |
| 30 | + int out_n_stride = output_height * output_width * channels; |
| 31 | + int out_c_stride = output_height * output_width; |
| 32 | + int index = blockIdx.x * blockDim.x + threadIdx.x; |
| 33 | + int offset = blockDim.x * gridDim.x; |
| 34 | + for (int i = index; i < nthreads; i += offset) { |
| 35 | + int bidx = i / in_n_stride; |
| 36 | + int boffset = i % in_n_stride; |
| 37 | + int cidx = boffset / in_c_stride; |
| 38 | + int out_offset = bidx * out_n_stride + cidx * out_c_stride; |
| 39 | + int out_index = indices_data[i]; |
| 40 | + PADDLE_ASSERT(out_index < out_c_stride); |
| 41 | + output_data[out_offset + out_index] = input_data[i]; |
| 42 | + } |
| 43 | +} |
| 44 | +template <typename T> |
| 45 | +__global__ void KernelUnpool2dMaxGrad( |
| 46 | + const int nthreads, const T* input_data, const int* indices_data, |
| 47 | + const int input_height, const int input_width, const int channels, |
| 48 | + const T* output_data, const T* output_grad, const int output_height, |
| 49 | + const int output_width, T* input_grad) { |
| 50 | + int in_n_stride = input_height * input_width * channels; |
| 51 | + int in_c_stride = input_height * input_width; |
| 52 | + int out_n_stride = output_height * output_width * channels; |
| 53 | + int out_c_stride = output_height * output_width; |
| 54 | + int index = blockIdx.x * blockDim.x + threadIdx.x; |
| 55 | + int offset = blockDim.x * gridDim.x; |
| 56 | + for (int i = index; i < nthreads; i += offset) { |
| 57 | + int bidx = i / in_n_stride; |
| 58 | + int boffset = i % in_n_stride; |
| 59 | + int cidx = boffset / in_c_stride; |
| 60 | + int out_offset = bidx * out_n_stride + cidx * out_c_stride; |
| 61 | + int out_index = indices_data[i]; |
| 62 | + PADDLE_ASSERT(out_index < out_c_stride); |
| 63 | + input_grad[i] = output_grad[out_offset + out_index]; |
| 64 | + } |
| 65 | +} |
| 66 | +/* |
| 67 | + * All tensors are in NCHW format. |
| 68 | + */ |
| 69 | +template <typename T> |
| 70 | +class Unpool2dMaxFunctor<platform::GPUPlace, T> { |
| 71 | + public: |
| 72 | + void operator()(const platform::DeviceContext& context, |
| 73 | + const framework::Tensor& input, |
| 74 | + const framework::Tensor& indices, framework::Tensor* output) { |
| 75 | + const int batch_size = input.dims()[0]; |
| 76 | + const int input_height = input.dims()[2]; |
| 77 | + const int input_width = input.dims()[3]; |
| 78 | + const int output_channels = output->dims()[1]; |
| 79 | + const int output_height = output->dims()[2]; |
| 80 | + const int output_width = output->dims()[3]; |
| 81 | + const T* input_data = input.data<T>(); |
| 82 | + const int* indices_data = indices.data<int>(); |
| 83 | + T* output_data = output->mutable_data<T>(context.GetPlace()); |
| 84 | + int threads = 1024; |
| 85 | + int grid = (input.numel() + threads - 1) / threads; |
| 86 | + KernelUnpool2dMax< |
| 87 | + T><<<grid, threads, 0, |
| 88 | + reinterpret_cast<const platform::CUDADeviceContext&>(context) |
| 89 | + .stream()>>>(input.numel(), input_data, indices_data, |
| 90 | + input_height, input_width, output_channels, |
| 91 | + output_data, output_height, output_width); |
| 92 | + } |
| 93 | +}; |
| 94 | +/* |
| 95 | + * All tensors are in NCHW format. |
| 96 | + */ |
| 97 | +template <typename T> |
| 98 | +class Unpool2dMaxGradFunctor<platform::GPUPlace, T> { |
| 99 | + public: |
| 100 | + void operator()(const platform::DeviceContext& context, |
| 101 | + const framework::Tensor& input, |
| 102 | + const framework::Tensor& indices, |
| 103 | + const framework::Tensor& output, |
| 104 | + const framework::Tensor& output_grad, |
| 105 | + framework::Tensor* input_grad) { |
| 106 | + const int batch_size = input.dims()[0]; |
| 107 | + const int input_height = input.dims()[2]; |
| 108 | + const int input_width = input.dims()[3]; |
| 109 | + const int output_channels = output.dims()[1]; |
| 110 | + const int output_height = output.dims()[2]; |
| 111 | + const int output_width = output.dims()[3]; |
| 112 | + const T* input_data = input.data<T>(); |
| 113 | + const int* indices_data = indices.data<int>(); |
| 114 | + const T* output_data = output.data<T>(); |
| 115 | + const T* output_grad_data = output_grad.data<T>(); |
| 116 | + T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); |
| 117 | + int threads = 1024; |
| 118 | + int grid = (input.numel() + threads - 1) / threads; |
| 119 | + KernelUnpool2dMaxGrad< |
| 120 | + T><<<grid, threads, 0, |
| 121 | + reinterpret_cast<const platform::CUDADeviceContext&>(context) |
| 122 | + .stream()>>>(input.numel(), input_data, indices_data, |
| 123 | + input_height, input_width, output_channels, |
| 124 | + output_data, output_grad_data, output_height, |
| 125 | + output_width, input_grad_data); |
| 126 | + } |
| 127 | +}; |
| 128 | +template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>; |
| 129 | +template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>; |
| 130 | +template class Unpool2dMaxFunctor<platform::GPUPlace, float>; |
| 131 | +template class Unpool2dMaxFunctor<platform::GPUPlace, double>; |
| 132 | +} // namespace math |
| 133 | +} // namespace operators |
| 134 | +} // namespace paddle |
0 commit comments