/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/roi_pool_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static constexpr int kROISize = 5;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}

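// GPUROIPoolForward assigns one output element per thread. input_rois is laid
// out as (num_rois, kROISize) rows of [batch_index, x1, y1, x2, y2] in input
// coordinates; spatial_scale maps them onto the feature map. Each thread
// decodes its linear index into (n, c, ph, pw), clips the corresponding bin
// against the feature map, and writes the bin maximum plus the flattened
// argmax position (or -1 for an empty bin).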
template <typename T>
__global__ void GPUROIPoolForward(
    const int nthreads, const T* input_data, const int64_t* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    T* output_data, int64_t* argmax_data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = offset_input_rois[0];
    int roi_start_w = round(offset_input_rois[1] * spatial_scale);
    int roi_start_h = round(offset_input_rois[2] * spatial_scale);
    int roi_end_w = round(offset_input_rois[3] * spatial_scale);
    int roi_end_h = round(offset_input_rois[4] * spatial_scale);

    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}

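// GPUROIPoolBackward routes each pooled-output gradient back to the input
// location recorded in argmax_data during the forward pass. CudaAtomicAdd is
// used because bins from different (possibly overlapping) ROIs can resolve to
// the same input element.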
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads, const int64_t* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, T* input_grad) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = offset_input_rois[0];
    int input_offset = (roi_batch_ind * channels + c) * height * width;
    int output_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_output_grad = output_grad + output_offset;
    T* offset_input_grad = input_grad + input_offset;
    const int64_t* offset_argmax_data = argmax_data + output_offset;

    int argmax = offset_argmax_data[ph * pooled_width + pw];
    if (argmax != -1) {
      platform::CudaAtomicAdd(
          offset_input_grad + argmax,
          static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
    }
  }
}

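// Forward op kernel: reads the pooling attributes and tensor shapes, then
// launches GPUROIPoolForward on the op's CUDA stream with a grid capped by
// NumBlocks(output_size). Illustrative shapes (hypothetical numbers): with X
// of shape (2, 256, 38, 50), ROIs of shape (128, 5) and pooled_height =
// pooled_width = 7, both Out and Argmax come out as (128, 256, 7, 7).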
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    size_t rois_num = rois->dims()[0];
    if (rois_num == 0) return;

    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    GPUROIPoolForward<T>
        <<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
            output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
            channels, height, width, pooled_height, pooled_width,
            out->mutable_data<T>(ctx.GetPlace()),
            argmax->mutable_data<int64_t>(ctx.GetPlace()));
  }
};

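// Backward op kernel: zero-fills the input gradient, then accumulates the
// output gradient into it at the argmax positions saved by the forward pass.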
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("ROIs");
    auto* argmax = ctx.Input<Tensor>("Argmax");

    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    size_t rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    if (x_grad) {
      x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.device_context(), x_grad, static_cast<T>(0));

      int output_grad_size = out_grad->numel();
      int blocks = NumBlocks(output_grad_size);
      int threads = kNumCUDAThreads;

      if (output_grad_size > 0) {
        GPUROIPoolBackward<T>
            <<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
                output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
                argmax->data<int64_t>(), rois_num, spatial_scale, channels,
                height, width, pooled_height, pooled_width,
                x_grad->mutable_data<T>(ctx.GetPlace()));
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

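// Register float and double CUDA kernels for roi_pool and its gradient.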
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, double>);