|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | + you may not use this file except in compliance with the License. |
| 4 | + You may obtain a copy of the License at |
| 5 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | + Unless required by applicable law or agreed to in writing, software |
| 7 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | + See the License for the specific language governing permissions and |
| 10 | + limitations under the License. */ |
| 11 | + |
| 12 | +#include "paddle/fluid/operators/bilinear_interp_op.h" |
| 13 | +#include "paddle/fluid/platform/cuda_helper.h" |
| 14 | + |
| 15 | +namespace paddle { |
| 16 | +namespace operators { |
| 17 | + |
| 18 | +using framework::Tensor; |
| 19 | + |
| 20 | +template <typename T> |
| 21 | +__global__ void KeBilinearInterpFw( |
| 22 | + const T* in, const size_t in_img_h, const size_t in_img_w, |
| 23 | + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, |
| 24 | + const size_t out_img_w, const size_t output_h, const size_t output_w, |
| 25 | + const size_t num_channels, const T ratio_h, const T ratioW) { |
| 26 | + int nthreads = output_h * output_w; |
| 27 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 28 | + if (tid < nthreads) { |
| 29 | + int out_id_h = tid / output_w; |
| 30 | + int out_id_w = tid % output_w; |
| 31 | + int in_img_size = input_w / num_channels; |
| 32 | + int out_img_size = output_w / num_channels; |
| 33 | + int channel_id = out_id_w / out_img_size; |
| 34 | + |
| 35 | + int out_img_idy = (out_id_w % out_img_size) / out_img_w; |
| 36 | + int in_img_idy = ratio_h * out_img_idy; |
| 37 | + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; |
| 38 | + T h1lambda = ratio_h * out_img_idy - in_img_idy; |
| 39 | + T h2lambda = 1.f - h1lambda; |
| 40 | + |
| 41 | + int out_img_idx = tid % out_img_w; |
| 42 | + int in_img_idx = ratioW * out_img_idx; |
| 43 | + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; |
| 44 | + T w1lambda = ratioW * out_img_idx - in_img_idx; |
| 45 | + T w2lambda = 1.f - w1lambda; |
| 46 | + |
| 47 | + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + |
| 48 | + in_img_idy * in_img_w + in_img_idx]; |
| 49 | + |
| 50 | + // bilinear interpolation |
| 51 | + out[out_id_h * output_w + out_id_w] = |
| 52 | + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + |
| 53 | + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + |
| 54 | + w1lambda * in_pos[h_id * in_img_w + w_id]); |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +template <typename T> |
| 59 | +__global__ void KeBilinearInterpBw( |
| 60 | + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, |
| 61 | + const size_t input_w, const T* out, const size_t out_img_h, |
| 62 | + const size_t out_img_w, const size_t output_h, const size_t output_w, |
| 63 | + const size_t num_channels, const T ratio_h, const T ratioW) { |
| 64 | + int nthreads = output_h * output_w; |
| 65 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 66 | + if (tid < nthreads) { |
| 67 | + int out_id_h = tid / output_w; |
| 68 | + int out_id_w = tid % output_w; |
| 69 | + int in_img_size = input_w / num_channels; |
| 70 | + int out_img_size = output_w / num_channels; |
| 71 | + int channel_id = out_id_w / out_img_size; |
| 72 | + |
| 73 | + int out_img_idy = (out_id_w % out_img_size) / out_img_w; |
| 74 | + int in_img_idy = ratio_h * out_img_idy; |
| 75 | + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; |
| 76 | + T h1lambda = ratio_h * out_img_idy - in_img_idy; |
| 77 | + T h2lambda = 1.f - h1lambda; |
| 78 | + |
| 79 | + int out_img_idx = tid % out_img_w; |
| 80 | + int in_img_idx = ratioW * out_img_idx; |
| 81 | + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; |
| 82 | + T w1lambda = ratioW * out_img_idx - in_img_idx; |
| 83 | + T w2lambda = 1.f - w1lambda; |
| 84 | + |
| 85 | + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + |
| 86 | + in_img_idy * in_img_w + in_img_idx]; |
| 87 | + const T* out_pos = &out[out_id_h * output_w + out_id_w]; |
| 88 | + atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); |
| 89 | + atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); |
| 90 | + atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); |
| 91 | + atomicAdd(&in_pos[h_id * in_img_w + w_id], |
| 92 | + h1lambda * w1lambda * out_pos[0]); |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +template <typename T> |
| 97 | +class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> { |
| 98 | + public: |
| 99 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 100 | + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), |
| 101 | + "This kernel only runs on GPU device."); |
| 102 | + auto* input_t = ctx.Input<Tensor>("X"); // float tensor |
| 103 | + auto* output_t = ctx.Output<Tensor>("Out"); // float tensor |
| 104 | + auto* input = input_t->data<T>(); |
| 105 | + auto* output = output_t->mutable_data<T>(ctx.GetPlace()); |
| 106 | + |
| 107 | + int out_h = ctx.Attr<int>("out_h"); |
| 108 | + int out_w = ctx.Attr<int>("out_w"); |
| 109 | + int batch_size = input_t->dims()[0]; |
| 110 | + int channels = input_t->dims()[1]; |
| 111 | + int in_h = input_t->dims()[2]; |
| 112 | + int in_w = input_t->dims()[3]; |
| 113 | + |
| 114 | + int in_hw = in_h * in_w; |
| 115 | + int out_hw = out_h * out_w; |
| 116 | + int in_chw = channels * in_hw; |
| 117 | + int out_chw = channels * out_hw; |
| 118 | + |
| 119 | + T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; |
| 120 | + T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; |
| 121 | + |
| 122 | + if (in_h == out_h && in_w == out_w) { |
| 123 | + memcpy(output, input, input_t->numel() * sizeof(T)); |
| 124 | + } else { |
| 125 | + int threadNum = batch_size * out_chw; |
| 126 | + int blocks = (threadNum + 1024 - 1) / 1024; |
| 127 | + |
| 128 | + KeBilinearInterpFw< |
| 129 | + T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>( |
| 130 | + input, in_h, in_w, batch_size, in_chw, output, out_h, out_w, |
| 131 | + batch_size, out_chw, channels, ratio_h, ratio_w); |
| 132 | + } |
| 133 | + } |
| 134 | +}; |
| 135 | + |
| 136 | +template <typename T> |
| 137 | +class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> { |
| 138 | + public: |
| 139 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 140 | + auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X")); |
| 141 | + auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out")); |
| 142 | + auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace()); |
| 143 | + auto* d_output = d_output_t->data<T>(); |
| 144 | + |
| 145 | + auto& device_ctx = |
| 146 | + ctx.template device_context<platform::CUDADeviceContext>(); |
| 147 | + math::SetConstant<platform::CUDADeviceContext, T> zero; |
| 148 | + zero(device_ctx, d_input_t, static_cast<T>(0.0)); |
| 149 | + |
| 150 | + int out_h = ctx.Attr<int>("out_h"); |
| 151 | + int out_w = ctx.Attr<int>("out_w"); |
| 152 | + int batch_size = d_input_t->dims()[0]; |
| 153 | + int channels = d_input_t->dims()[1]; |
| 154 | + int in_h = d_input_t->dims()[2]; |
| 155 | + int in_w = d_input_t->dims()[3]; |
| 156 | + |
| 157 | + int in_hw = in_h * in_w; |
| 158 | + int out_hw = out_h * out_w; |
| 159 | + int in_chw = channels * in_hw; |
| 160 | + int out_chw = channels * out_hw; |
| 161 | + |
| 162 | + T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; |
| 163 | + T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; |
| 164 | + |
| 165 | + if (in_h == out_h && in_w == out_w) { |
| 166 | + memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); |
| 167 | + } else { |
| 168 | + int threadNum = batch_size * out_chw; |
| 169 | + int blocks = (threadNum + 1024 - 1) / 1024; |
| 170 | + |
| 171 | + KeBilinearInterpBw< |
| 172 | + T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>( |
| 173 | + d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w, |
| 174 | + batch_size, out_chw, channels, ratio_h, ratio_w); |
| 175 | + } |
| 176 | + } |
| 177 | +}; |
| 178 | + |
| 179 | +} // namespace operators |
| 180 | +} // namespace paddle |
| 181 | + |
| 182 | +namespace ops = paddle::operators; |
| 183 | +REGISTER_OP_CUDA_KERNEL(bilinear_interp, |
| 184 | + ops::BilinearInterpOpCUDAKernel<float>); |
| 185 | +REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad, |
| 186 | + ops::BilinearInterpGradOpCUDAKernel<float>); |
0 commit comments