|
9 | 9 | See the License for the specific language governing permissions and
|
10 | 10 | limitations under the License. */
|
11 | 11 |
|
12 |
| -#include "paddle/fluid/framework/op_registry.h" |
13 |
| -#include "paddle/fluid/operators/bilinear_interp_op.cu.h" |
14 |
| -#include "paddle/fluid/operators/math/math_function.h" |
| 12 | +#include "paddle/fluid/operators/bilinear_interp_op.h" |
| 13 | +#include "paddle/fluid/platform/cuda_helper.h" |
15 | 14 |
|
16 | 15 | namespace paddle {
|
17 | 16 | namespace operators {
|
18 | 17 |
|
19 | 18 | using framework::Tensor;
|
20 | 19 |
|
| 20 | +template <typename T> |
| 21 | +__global__ void KeBilinearInterpFw( |
| 22 | + const T* in, const size_t in_img_h, const size_t in_img_w, |
| 23 | + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, |
| 24 | + const size_t out_img_w, const size_t output_h, const size_t output_w, |
| 25 | + const size_t num_channels, const T ratio_h, const T ratioW) { |
| 26 | + int nthreads = output_h * output_w; |
| 27 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 28 | + if (tid < nthreads) { |
| 29 | + int out_id_h = tid / output_w; |
| 30 | + int out_id_w = tid % output_w; |
| 31 | + int in_img_size = input_w / num_channels; |
| 32 | + int out_img_size = output_w / num_channels; |
| 33 | + int channel_id = out_id_w / out_img_size; |
| 34 | + |
| 35 | + int out_img_idy = (out_id_w % out_img_size) / out_img_w; |
| 36 | + int in_img_idy = ratio_h * out_img_idy; |
| 37 | + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; |
| 38 | + T h1lambda = ratio_h * out_img_idy - in_img_idy; |
| 39 | + T h2lambda = 1.f - h1lambda; |
| 40 | + |
| 41 | + int out_img_idx = tid % out_img_w; |
| 42 | + int in_img_idx = ratioW * out_img_idx; |
| 43 | + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; |
| 44 | + T w1lambda = ratioW * out_img_idx - in_img_idx; |
| 45 | + T w2lambda = 1.f - w1lambda; |
| 46 | + |
| 47 | + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + |
| 48 | + in_img_idy * in_img_w + in_img_idx]; |
| 49 | + |
| 50 | + // bilinear interpolation |
| 51 | + out[out_id_h * output_w + out_id_w] = |
| 52 | + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + |
| 53 | + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + |
| 54 | + w1lambda * in_pos[h_id * in_img_w + w_id]); |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +template <typename T> |
| 59 | +__global__ void KeBilinearInterpBw( |
| 60 | + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, |
| 61 | + const size_t input_w, const T* out, const size_t out_img_h, |
| 62 | + const size_t out_img_w, const size_t output_h, const size_t output_w, |
| 63 | + const size_t num_channels, const T ratio_h, const T ratioW) { |
| 64 | + int nthreads = output_h * output_w; |
| 65 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 66 | + if (tid < nthreads) { |
| 67 | + int out_id_h = tid / output_w; |
| 68 | + int out_id_w = tid % output_w; |
| 69 | + int in_img_size = input_w / num_channels; |
| 70 | + int out_img_size = output_w / num_channels; |
| 71 | + int channel_id = out_id_w / out_img_size; |
| 72 | + |
| 73 | + int out_img_idy = (out_id_w % out_img_size) / out_img_w; |
| 74 | + int in_img_idy = ratio_h * out_img_idy; |
| 75 | + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; |
| 76 | + T h1lambda = ratio_h * out_img_idy - in_img_idy; |
| 77 | + T h2lambda = 1.f - h1lambda; |
| 78 | + |
| 79 | + int out_img_idx = tid % out_img_w; |
| 80 | + int in_img_idx = ratioW * out_img_idx; |
| 81 | + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; |
| 82 | + T w1lambda = ratioW * out_img_idx - in_img_idx; |
| 83 | + T w2lambda = 1.f - w1lambda; |
| 84 | + |
| 85 | + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + |
| 86 | + in_img_idy * in_img_w + in_img_idx]; |
| 87 | + const T* out_pos = &out[out_id_h * output_w + out_id_w]; |
| 88 | + atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); |
| 89 | + atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); |
| 90 | + atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]); |
| 91 | + atomicAdd(&in_pos[h_id * in_img_w + w_id], |
| 92 | + h1lambda * w1lambda * out_pos[0]); |
| 93 | + } |
| 94 | +} |
| 95 | + |
21 | 96 | template <typename T>
|
22 | 97 | class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
|
23 | 98 | public:
|
|
0 commit comments