
Commit c7cd6d1

Author: wangyang59
CPU implementation of bilinear interp
1 parent 504e60a commit c7cd6d1

File tree

2 files changed: +227, -0 lines changed

paddle/fluid/operators/bilinear_interp_op.cc
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/bilinear_interp_op.h"
13+
14+
namespace paddle {
15+
namespace operators {
16+
17+
using framework::Tensor;
18+
19+
class BilinearInterpOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of BilinearInterpOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of BilinearInterpOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");  // NCHW format
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");

    std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
    ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
  }
};

class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The input tensor of bilinear interpolation, 4-D with NCHW shape");
    AddOutput("Out",
              "The output tensor of bilinear interpolation, "
              "4-D with shape [N, C, out_h, out_w]");
    AddAttr<int>("out_h", "output height of bilinear interpolation op.");
    AddAttr<int>("out_w", "output width of bilinear interpolation op.");
    AddComment(R"DOC(
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and W-direction
in this op) on a rectilinear 2D grid.

The key idea is to perform linear interpolation first in one direction,
and then again in the other direction.

For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC");
  }
};

class BilinearInterpOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(bilinear_interp, ops::BilinearInterpOp, ops::BilinearInterpOpMaker,
            bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
                       ops::BilinearInterpGradKernel<float>);
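
As a reading aid, here is a minimal standalone sketch (not part of this commit) of the align-corners arithmetic that the kernels in bilinear_interp_op.h below implement: output pixel (i, j) is mapped to the input coordinate (ratio_h * i, ratio_w * j), and the fractional parts become the blending weights of the four surrounding input pixels. The 2x2 input, the hard-coded sizes, and the main() harness are illustrative assumptions, not code from the commit.

// Toy bilinear upsample of one 2x2 single-channel image to 4x4, using the
// same ratio/lambda arithmetic as BilinearInterpKernel (illustration only).
#include <cstdio>

int main() {
  const int in_h = 2, in_w = 2, out_h = 4, out_w = 4;
  float in[in_h][in_w] = {{0.f, 1.f}, {2.f, 3.f}};
  // "align corners": the first/last output pixels coincide with the
  // first/last input pixels, hence the (size - 1) ratios
  float ratio_h = (out_h > 1) ? float(in_h - 1) / (out_h - 1) : 0.f;
  float ratio_w = (out_w > 1) ? float(in_w - 1) / (out_w - 1) : 0.f;

  for (int i = 0; i < out_h; ++i) {
    int h = static_cast<int>(ratio_h * i);        // top source row
    int hid = (h < in_h - 1) ? 1 : 0;             // step to the bottom row
    float h1 = ratio_h * i - h, h2 = 1.f - h1;    // row blending weights
    for (int j = 0; j < out_w; ++j) {
      int w = static_cast<int>(ratio_w * j);      // left source column
      int wid = (w < in_w - 1) ? 1 : 0;           // step to the right column
      float w1 = ratio_w * j - w, w2 = 1.f - w1;  // column blending weights
      float v = h2 * (w2 * in[h][w] + w1 * in[h][w + wid]) +
                h1 * (w2 * in[h + hid][w] + w1 * in[h + hid][w + wid]);
      printf("%5.3f ", v);
    }
    printf("\n");
  }
  return 0;
}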
paddle/fluid/operators/bilinear_interp_op.h
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstring>  // memcpy / memset below

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename T>
class BilinearInterpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_t = ctx.Input<Tensor>("X");      // float tensor
    auto* output_t = ctx.Output<Tensor>("Out");  // float tensor
    auto* input = input_t->data<T>();
    auto* output = output_t->mutable_data<T>(ctx.GetPlace());

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = input_t->dims()[0];
    int channels = input_t->dims()[1];
    int in_h = input_t->dims()[2];
    int in_w = input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(output, input, product(input_t->dims()) * sizeof(T));
    } else {
      for (int k = 0; k < batch_size; ++k) {  // loop over batches
        for (int i = 0; i < out_h; ++i) {     // loop over output rows
          int h = ratio_h * i;                // top source row
          int hid = (h < in_h - 1) ? 1 : 0;   // step to the bottom source row
          T h1lambda = ratio_h * i - h;       // weight of the bottom row
          T h2lambda = 1 - h1lambda;          // weight of the top row

          for (int j = 0; j < out_w; ++j) {    // loop over output columns
            int w = ratio_w * j;               // left source column
            int wid = (w < in_w - 1) ? 1 : 0;  // step to the right column
            T w1lambda = ratio_w * j - w;      // weight of the right column
            T w2lambda = 1 - w1lambda;         // weight of the left column
            // locate the top-left of the four neighboring input positions
            const T* in_pos = &input[k * in_chw + h * in_w + w];
            T* out_pos = &output[k * out_chw + i * out_w + j];

            for (int c = 0; c < channels; ++c) {  // loop over channels
              // bilinear interpolation of the four neighbors
              out_pos[0] =
                  h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
                  h1lambda * (w2lambda * in_pos[hid * in_w] +
                              w1lambda * in_pos[hid * in_w + wid]);
              in_pos += in_hw;
              out_pos += out_hw;
            }
          }
        }
      }
    }
  }
};

template <typename T>
class BilinearInterpGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
    auto* d_output = d_output_t->data<T>();

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = d_input_t->dims()[0];
    int channels = d_input_t->dims()[1];
    int in_h = d_input_t->dims()[2];
    int in_w = d_input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      memcpy(d_input, d_output, product(d_input_t->dims()) * sizeof(T));
    } else {
      // different output pixels can scatter into the same input cell, so
      // the gradient must be zero-initialized and accumulated, not assigned
      memset(d_input, 0, product(d_input_t->dims()) * sizeof(T));

      for (int k = 0; k < batch_size; ++k) {  // loop over batches
        for (int i = 0; i < out_h; ++i) {     // loop over output rows
          int h = ratio_h * i;
          int hid = (h < in_h - 1) ? 1 : 0;
          T h1lambda = ratio_h * i - h;
          T h2lambda = 1 - h1lambda;

          for (int j = 0; j < out_w; ++j) {  // loop over output columns
            int w = ratio_w * j;
            int wid = (w < in_w - 1) ? 1 : 0;
            T w1lambda = ratio_w * j - w;
            T w2lambda = 1 - w1lambda;
            T* in_pos = &d_input[k * in_chw + h * in_w + w];
            const T* out_pos = &d_output[k * out_chw + i * out_w + j];

            for (int c = 0; c < channels; ++c) {  // loop over channels
              // scatter the output gradient to the four neighbors, weighted
              // by the same coefficients as the forward interpolation
              in_pos[0] += h2lambda * w2lambda * out_pos[0];
              in_pos[wid] += h2lambda * w1lambda * out_pos[0];
              in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0];
              in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0];
              in_pos += in_hw;
              out_pos += out_hw;
            }
          }
        }
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
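
One subtlety worth calling out: BilinearInterpGradKernel scatters each upstream gradient value into up to four input cells, and different output pixels can land on the same cells, which is why the backward pass above zero-initializes d_input and accumulates with +=. The standalone check below (an illustration under assumed sizes, not code from the commit) verifies that the four scatter weights always sum to 1, so each output pixel contributes exactly its own gradient mass.

// Check that h2*w2 + h2*w1 + h1*w2 + h1*w1 == (h1 + h2) * (w1 + w2) == 1
// for every output pixel, using the same weight arithmetic as the kernels.
#include <cassert>
#include <cstdio>

int main() {
  const int in_h = 3, in_w = 5, out_h = 7, out_w = 11;  // arbitrary sizes
  float ratio_h = float(in_h - 1) / (out_h - 1);
  float ratio_w = float(in_w - 1) / (out_w - 1);
  for (int i = 0; i < out_h; ++i) {
    int h = static_cast<int>(ratio_h * i);
    float h1 = ratio_h * i - h, h2 = 1.f - h1;
    for (int j = 0; j < out_w; ++j) {
      int w = static_cast<int>(ratio_w * j);
      float w1 = ratio_w * j - w, w2 = 1.f - w1;
      float sum = h2 * w2 + h2 * w1 + h1 * w2 + h1 * w1;
      assert(sum > 0.999f && sum < 1.001f);  // weights form a partition of 1
    }
  }
  printf("all scatter-weight sums equal 1\n");
  return 0;
}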
