|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | + you may not use this file except in compliance with the License. |
| 4 | + You may obtain a copy of the License at |
| 5 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | + Unless required by applicable law or agreed to in writing, software |
| 7 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | + See the License for the specific language governing permissions and |
| 10 | + limitations under the License. */ |
| 11 | + |
| 12 | +#pragma once |
| 13 | +#include "paddle/fluid/framework/eigen.h" |
| 14 | +#include "paddle/fluid/framework/op_registry.h" |
| 15 | + |
| 16 | +namespace paddle { |
| 17 | +namespace operators { |
| 18 | + |
| 19 | +using Tensor = framework::Tensor; |
| 20 | +template <typename T, int MajorType = Eigen::RowMajor, |
| 21 | + typename IndexType = Eigen::DenseIndex> |
| 22 | +using EigenVector = framework::EigenVector<T, MajorType, IndexType>; |
| 23 | + |
| 24 | +template <typename T> |
| 25 | +class BilinearInterpKernel : public framework::OpKernel<T> { |
| 26 | + public: |
| 27 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 28 | + auto* input_t = ctx.Input<Tensor>("X"); // float tensor |
| 29 | + auto* output_t = ctx.Output<Tensor>("Out"); // float tensor |
| 30 | + auto* input = input_t->data<T>(); |
| 31 | + auto* output = output_t->mutable_data<T>(ctx.GetPlace()); |
| 32 | + |
| 33 | + int out_h = ctx.Attr<int>("out_h"); |
| 34 | + int out_w = ctx.Attr<int>("out_w"); |
| 35 | + int batch_size = input_t->dims()[0]; |
| 36 | + int channels = input_t->dims()[1]; |
| 37 | + int in_h = input_t->dims()[2]; |
| 38 | + int in_w = input_t->dims()[3]; |
| 39 | + |
| 40 | + int in_hw = in_h * in_w; |
| 41 | + int out_hw = out_h * out_w; |
| 42 | + int in_chw = channels * in_hw; |
| 43 | + int out_chw = channels * out_hw; |
| 44 | + |
| 45 | + T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; |
| 46 | + T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; |
| 47 | + |
| 48 | + if (in_h == out_h && in_w == out_w) { |
| 49 | + memcpy(output, input, product(input_t->dims()) * sizeof(T)); |
| 50 | + } else { |
| 51 | + for (int k = 0; k < batch_size; ++k) { // loop for batches |
| 52 | + for (int i = 0; i < out_h; ++i) { // loop for images |
| 53 | + int h = ratio_h * i; |
| 54 | + int hid = (h < in_h - 1) ? 1 : 0; |
| 55 | + T h1lambda = ratio_h * i - h; |
| 56 | + T h2lambda = 1 - h1lambda; |
| 57 | + |
| 58 | + for (int j = 0; j < out_w; ++j) { |
| 59 | + int w = ratio_w * j; |
| 60 | + int wid = (w < in_w - 1) ? 1 : 0; |
| 61 | + T w1lambda = ratio_w * j - w; |
| 62 | + T w2lambda = 1 - w1lambda; |
| 63 | + // calculate four position for bilinear interpolation |
| 64 | + const T* in_pos = &input[k * in_chw + h * in_w + w]; |
| 65 | + T* out_pos = &output[k * out_chw + i * out_w + j]; |
| 66 | + |
| 67 | + for (int c = 0; c < channels; ++c) { // loop for channels |
| 68 | + // bilinear interpolation |
| 69 | + out_pos[0] = |
| 70 | + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) + |
| 71 | + h1lambda * (w2lambda * in_pos[hid * in_w] + |
| 72 | + w1lambda * in_pos[hid * in_w + wid]); |
| 73 | + in_pos += in_hw; |
| 74 | + out_pos += out_hw; |
| 75 | + } |
| 76 | + } |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + } |
| 81 | +}; |
| 82 | + |
| 83 | +template <typename T> |
| 84 | +class BilinearInterpGradKernel : public framework::OpKernel<T> { |
| 85 | + public: |
| 86 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 87 | + auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X")); |
| 88 | + auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out")); |
| 89 | + auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace()); |
| 90 | + auto* d_output = d_output_t->data<T>(); |
| 91 | + |
| 92 | + int out_h = ctx.Attr<int>("out_h"); |
| 93 | + int out_w = ctx.Attr<int>("out_w"); |
| 94 | + int batch_size = d_input_t->dims()[0]; |
| 95 | + int channels = d_input_t->dims()[1]; |
| 96 | + int in_h = d_input_t->dims()[2]; |
| 97 | + int in_w = d_input_t->dims()[3]; |
| 98 | + |
| 99 | + int in_hw = in_h * in_w; |
| 100 | + int out_hw = out_h * out_w; |
| 101 | + int in_chw = channels * in_hw; |
| 102 | + int out_chw = channels * out_hw; |
| 103 | + |
| 104 | + T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; |
| 105 | + T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; |
| 106 | + |
| 107 | + if (in_h == out_h && in_w == out_w) { |
| 108 | + memcpy(d_input, d_output, product(d_input_t->dims()) * sizeof(T)); |
| 109 | + } else { |
| 110 | + for (int k = 0; k < batch_size; ++k) { // loop for batches |
| 111 | + for (int i = 0; i < out_h; ++i) { // loop for images |
| 112 | + int h = ratio_h * i; |
| 113 | + int hid = (h < in_h - 1) ? 1 : 0; |
| 114 | + T h1lambda = ratio_h * i - h; |
| 115 | + T h2lambda = 1 - h1lambda; |
| 116 | + |
| 117 | + for (int j = 0; j < out_w; ++j) { |
| 118 | + int w = ratio_w * j; |
| 119 | + int wid = (w < in_w - 1) ? 1 : 0; |
| 120 | + T w1lambda = ratio_w * j - w; |
| 121 | + T w2lambda = 1 - w1lambda; |
| 122 | + T* in_pos = &d_input[k * in_chw + h * in_w + w]; |
| 123 | + const T* out_pos = &d_output[k * out_chw + i * out_w + j]; |
| 124 | + |
| 125 | + for (int c = 0; c < channels; ++c) { // loop for channels |
| 126 | + in_pos[0] = h2lambda * w2lambda * out_pos[0]; |
| 127 | + in_pos[wid] = h2lambda * w1lambda * out_pos[0]; |
| 128 | + in_pos[hid * in_w] = h1lambda * w2lambda * out_pos[0]; |
| 129 | + in_pos[hid * in_w + wid] = h1lambda * w1lambda * out_pos[0]; |
| 130 | + in_pos += in_hw; |
| 131 | + out_pos += out_hw; |
| 132 | + } |
| 133 | + } |
| 134 | + } |
| 135 | + } |
| 136 | + } |
| 137 | + } |
| 138 | +}; |
| 139 | + |
| 140 | +} // namespace operators |
| 141 | +} // namespace paddle |
0 commit comments