|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/operators/detection/density_prior_box_op.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +template <typename T> |
| 21 | +static __device__ inline T Clip(T in) { |
| 22 | + return min(max(in, 0.), 1.); |
| 23 | +} |
| 24 | + |
| 25 | +template <typename T> |
| 26 | +static __global__ void GenDensityPriorBox( |
| 27 | + const int height, const int width, const int im_height, const int im_width, |
| 28 | + const T offset, const T step_width, const T step_height, |
| 29 | + const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin, |
| 30 | + const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) { |
| 31 | + int gidx = blockIdx.x * blockDim.x + threadIdx.x; |
| 32 | + int gidy = blockIdx.y * blockDim.y + threadIdx.y; |
| 33 | + int step_x = blockDim.x * gridDim.x; |
| 34 | + int step_y = blockDim.y * gridDim.y; |
| 35 | + |
| 36 | + const T* width_ratio = ratios_shift; |
| 37 | + const T* height_ratio = ratios_shift + num_priors; |
| 38 | + const T* width_shift = ratios_shift + 2 * num_priors; |
| 39 | + const T* height_shift = ratios_shift + 3 * num_priors; |
| 40 | + |
| 41 | + for (int j = gidy; j < height; j += step_y) { |
| 42 | + for (int i = gidx; i < width * num_priors; i += step_x) { |
| 43 | + int h = j; |
| 44 | + int w = i / num_priors; |
| 45 | + int k = i % num_priors; |
| 46 | + |
| 47 | + T center_x = (w + offset) * step_width; |
| 48 | + T center_y = (h + offset) * step_height; |
| 49 | + |
| 50 | + T center_x_temp = center_x + width_shift[k]; |
| 51 | + T center_y_temp = center_y + height_shift[k]; |
| 52 | + |
| 53 | + T box_width_ratio = width_ratio[k] / 2.; |
| 54 | + T box_height_ratio = height_ratio[k] / 2.; |
| 55 | + |
| 56 | + T xmin = max((center_x_temp - box_width_ratio) / im_width, 0.); |
| 57 | + T ymin = max((center_y_temp - box_height_ratio) / im_height, 0.); |
| 58 | + T xmax = min((center_x_temp + box_width_ratio) / im_width, 1.); |
| 59 | + T ymax = min((center_y_temp + box_height_ratio) / im_height, 1.); |
| 60 | + |
| 61 | + int out_offset = (j * width * num_priors + i) * 4; |
| 62 | + out[out_offset] = is_clip ? Clip<T>(xmin) : xmin; |
| 63 | + out[out_offset + 1] = is_clip ? Clip<T>(ymin) : ymin; |
| 64 | + out[out_offset + 2] = is_clip ? Clip<T>(xmax) : xmax; |
| 65 | + out[out_offset + 3] = is_clip ? Clip<T>(ymax) : ymax; |
| 66 | + |
| 67 | + var[out_offset] = var_xmin; |
| 68 | + var[out_offset + 1] = var_ymin; |
| 69 | + var[out_offset + 2] = var_xmax; |
| 70 | + var[out_offset + 3] = var_ymax; |
| 71 | + } |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +template <typename T> |
| 76 | +class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> { |
| 77 | + public: |
| 78 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 79 | + auto* input = ctx.Input<paddle::framework::Tensor>("Input"); |
| 80 | + auto* image = ctx.Input<paddle::framework::Tensor>("Image"); |
| 81 | + auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes"); |
| 82 | + auto* vars = ctx.Output<paddle::framework::Tensor>("Variances"); |
| 83 | + |
| 84 | + auto variances = ctx.Attr<std::vector<float>>("variances"); |
| 85 | + auto is_clip = ctx.Attr<bool>("clip"); |
| 86 | + |
| 87 | + auto fixed_sizes = ctx.Attr<std::vector<float>>("fixed_sizes"); |
| 88 | + auto fixed_ratios = ctx.Attr<std::vector<float>>("fixed_ratios"); |
| 89 | + auto densities = ctx.Attr<std::vector<int>>("densities"); |
| 90 | + |
| 91 | + T step_w = static_cast<T>(ctx.Attr<float>("step_w")); |
| 92 | + T step_h = static_cast<T>(ctx.Attr<float>("step_h")); |
| 93 | + T offset = static_cast<T>(ctx.Attr<float>("offset")); |
| 94 | + |
| 95 | + auto img_width = image->dims()[3]; |
| 96 | + auto img_height = image->dims()[2]; |
| 97 | + |
| 98 | + auto feature_width = input->dims()[3]; |
| 99 | + auto feature_height = input->dims()[2]; |
| 100 | + |
| 101 | + T step_width, step_height; |
| 102 | + if (step_w == 0 || step_h == 0) { |
| 103 | + step_width = static_cast<T>(img_width) / feature_width; |
| 104 | + step_height = static_cast<T>(img_height) / feature_height; |
| 105 | + } else { |
| 106 | + step_width = step_w; |
| 107 | + step_height = step_h; |
| 108 | + } |
| 109 | + |
| 110 | + int num_priors = 0; |
| 111 | + for (size_t i = 0; i < densities.size(); ++i) { |
| 112 | + num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); |
| 113 | + } |
| 114 | + int step_average = static_cast<int>((step_width + step_height) * 0.5); |
| 115 | + |
| 116 | + framework::Tensor h_temp; |
| 117 | + T* tdata = h_temp.mutable_data<T>({num_priors * 4}, platform::CPUPlace()); |
| 118 | + int idx = 0; |
| 119 | + for (size_t s = 0; s < fixed_sizes.size(); ++s) { |
| 120 | + auto fixed_size = fixed_sizes[s]; |
| 121 | + int density = densities[s]; |
| 122 | + for (size_t r = 0; r < fixed_ratios.size(); ++r) { |
| 123 | + float ar = fixed_ratios[r]; |
| 124 | + int shift = step_average / density; |
| 125 | + float box_width_ratio = fixed_size * sqrt(ar); |
| 126 | + float box_height_ratio = fixed_size / sqrt(ar); |
| 127 | + for (int di = 0; di < density; ++di) { |
| 128 | + for (int dj = 0; dj < density; ++dj) { |
| 129 | + float center_x_temp = shift / 2. + dj * shift - step_average / 2.; |
| 130 | + float center_y_temp = shift / 2. + di * shift - step_average / 2.; |
| 131 | + tdata[idx] = box_width_ratio; |
| 132 | + tdata[num_priors + idx] = box_height_ratio; |
| 133 | + tdata[2 * num_priors + idx] = center_x_temp; |
| 134 | + tdata[3 * num_priors + idx] = center_y_temp; |
| 135 | + idx++; |
| 136 | + } |
| 137 | + } |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + boxes->mutable_data<T>(ctx.GetPlace()); |
| 142 | + vars->mutable_data<T>(ctx.GetPlace()); |
| 143 | + |
| 144 | + framework::Tensor d_temp; |
| 145 | + framework::TensorCopySync(h_temp, ctx.GetPlace(), &d_temp); |
| 146 | + |
| 147 | + // At least use 32 threads, at most 512 threads. |
| 148 | + // blockx is multiple of 32. |
| 149 | + int blockx = std::min(((feature_width * num_priors + 31) >> 5) << 5, 512L); |
| 150 | + int gridx = (feature_width * num_priors + blockx - 1) / blockx; |
| 151 | + dim3 threads(blockx, 1); |
| 152 | + dim3 grids(gridx, feature_height); |
| 153 | + |
| 154 | + auto stream = |
| 155 | + ctx.template device_context<platform::CUDADeviceContext>().stream(); |
| 156 | + GenDensityPriorBox<T><<<grids, threads, 0, stream>>>( |
| 157 | + feature_height, feature_width, img_height, img_width, offset, |
| 158 | + step_width, step_height, num_priors, d_temp.data<T>(), is_clip, |
| 159 | + variances[0], variances[1], variances[2], variances[3], |
| 160 | + boxes->data<T>(), vars->data<T>()); |
| 161 | + } |
| 162 | +}; // namespace operators |
| 163 | + |
| 164 | +} // namespace operators |
| 165 | +} // namespace paddle |
| 166 | + |
| 167 | +namespace ops = paddle::operators; |
| 168 | +REGISTER_OP_CUDA_KERNEL(density_prior_box, |
| 169 | + ops::DensityPriorBoxOpCUDAKernel<float>, |
| 170 | + ops::DensityPriorBoxOpCUDAKernel<double>); |
0 commit comments