
Commit 72ee737

Author: wangyang59

Merge pull request #9308 from wangyang59/bilinear

Bilinear interp op

2 parents: 2182ecf + bf021f3

File tree

4 files changed: +518 -0 lines changed

paddle/fluid/operators/bilinear_interp_op.cc
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/bilinear_interp_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class BilinearInterpOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of BilinearInterpOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of BilinearInterpOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");  // NCHW format
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");

    // N and C are preserved; H and W come from the out_h/out_w attributes.
    std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
    ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
  }
};

class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input tensor of bilinear interpolation. "
             "This is a 4-D tensor with shape (N x C x h x w).");
    AddOutput("Out",
              "(Tensor) The output tensor with shape (N x C x out_h x out_w).");

    AddAttr<int>("out_h", "(int) output height of bilinear interpolation op.");
    AddAttr<int>("out_w", "(int) output width of bilinear interpolation op.");
    AddComment(R"DOC(
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. the H-direction and
W-direction in this op) on a rectilinear 2D grid.

The key idea is to perform linear interpolation first in one
direction, and then again in the other direction.

For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC");
  }
};

class BilinearInterpOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    // The input gradient has the same shape as the input itself.
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
                  ops::BilinearInterpOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
                       ops::BilinearInterpGradKernel<float>);
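
The DOC block above describes the method; the kernels later in this commit implement it with the (h1lambda, h2lambda) / (w1lambda, w2lambda) weights buried in NCHW index bookkeeping. As a reading aid, here is a minimal host-side C++ sketch of the same scheme for a single-channel image. It is illustrative only, not part of the commit, and the name BilinearResize is hypothetical.

// Illustrative only -- not part of this commit. Single-channel bilinear
// resize using the same corner-aligned mapping as the op above.
#include <cstdio>
#include <vector>

void BilinearResize(const std::vector<float>& in, int in_h, int in_w,
                    std::vector<float>* out, int out_h, int out_w) {
  // Corner-aligned mapping, matching ratio_h/ratio_w in the op:
  // src = (in - 1) / (out - 1) * dst, and 0 when the output has one row/col.
  float ratio_h = (out_h > 1) ? float(in_h - 1) / (out_h - 1) : 0.f;
  float ratio_w = (out_w > 1) ? float(in_w - 1) / (out_w - 1) : 0.f;
  out->assign(out_h * out_w, 0.f);
  for (int y = 0; y < out_h; ++y) {
    int y0 = static_cast<int>(ratio_h * y);        // top source row
    int y_id = (y0 < in_h - 1) ? 1 : 0;            // clamp at the border
    float h1 = ratio_h * y - y0, h2 = 1.f - h1;    // vertical weights
    for (int x = 0; x < out_w; ++x) {
      int x0 = static_cast<int>(ratio_w * x);      // left source column
      int x_id = (x0 < in_w - 1) ? 1 : 0;
      float w1 = ratio_w * x - x0, w2 = 1.f - w1;  // horizontal weights
      (*out)[y * out_w + x] =
          h2 * (w2 * in[y0 * in_w + x0] + w1 * in[y0 * in_w + x0 + x_id]) +
          h1 * (w2 * in[(y0 + y_id) * in_w + x0] +
                w1 * in[(y0 + y_id) * in_w + x0 + x_id]);
    }
  }
}

int main() {
  // Upscale a 2x2 image to 3x3: the center becomes the mean of all four.
  std::vector<float> in = {0.f, 1.f, 2.f, 3.f}, out;
  BilinearResize(in, 2, 2, &out, 3, 3);
  for (int y = 0; y < 3; ++y)
    printf("%.2f %.2f %.2f\n", out[y * 3], out[y * 3 + 1], out[y * 3 + 2]);
  // Prints:
  // 0.00 0.50 1.00
  // 1.00 1.50 2.00
  // 2.00 2.50 3.00
}

Note the effect of the corner-aligned ratio = (in - 1) / (out - 1): the first and last output rows and columns coincide exactly with the first and last input rows and columns.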
paddle/fluid/operators/bilinear_interp_op.cu
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/bilinear_interp_op.h"
#include "paddle/fluid/operators/math/math_function.h"  // for math::SetConstant
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
__global__ void KeBilinearInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratio_w) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    // Unpack the flat thread id into (batch, channel, y, x) of the output.
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    // Source row and vertical weights; h_id clamps at the bottom border.
    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = ratio_h * out_img_idy;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T h1lambda = ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    // Source column and horizontal weights; w_id clamps at the right border.
    int out_img_idx = tid % out_img_w;
    int in_img_idx = ratio_w * out_img_idx;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T w1lambda = ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                          in_img_idy * in_img_w + in_img_idx];

    // bilinear interpolation over the 2x2 source neighborhood
    out[out_id_h * output_w + out_id_w] =
        h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
        h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
                    w1lambda * in_pos[h_id * in_img_w + w_id]);
  }
}

template <typename T>
__global__ void KeBilinearInterpBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
    const size_t input_w, const T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const T ratio_h, const T ratio_w) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int out_id_h = tid / output_w;
    int out_id_w = tid % output_w;
    int in_img_size = input_w / num_channels;
    int out_img_size = output_w / num_channels;
    int channel_id = out_id_w / out_img_size;

    int out_img_idy = (out_id_w % out_img_size) / out_img_w;
    int in_img_idy = ratio_h * out_img_idy;
    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
    T h1lambda = ratio_h * out_img_idy - in_img_idy;
    T h2lambda = 1.f - h1lambda;

    int out_img_idx = tid % out_img_w;
    int in_img_idx = ratio_w * out_img_idx;
    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
    T w1lambda = ratio_w * out_img_idx - in_img_idx;
    T w2lambda = 1.f - w1lambda;

    T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
                    in_img_idy * in_img_w + in_img_idx];
    const T* out_pos = &out[out_id_h * output_w + out_id_w];
    // Scatter the output gradient back to the four source pixels. Atomics
    // are required because neighboring output pixels share source pixels.
    atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
    atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
    atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
    atomicAdd(&in_pos[h_id * in_img_w + w_id],
              h1lambda * w1lambda * out_pos[0]);
  }
}

template <typename T>
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_t = ctx.Input<Tensor>("X");      // float tensor
    auto* output_t = ctx.Output<Tensor>("Out");  // float tensor
    auto* input = input_t->data<T>();
    auto* output = output_t->mutable_data<T>(ctx.GetPlace());

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = input_t->dims()[0];
    int channels = input_t->dims()[1];
    int in_h = input_t->dims()[2];
    int in_w = input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    // Corner-aligned scale factors; 0 when the output has a single row/col.
    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      // Same size: just copy. Both pointers are device memory, so this must
      // be a device-to-device copy, not a host memcpy.
      cudaMemcpyAsync(output, input, input_t->numel() * sizeof(T),
                      cudaMemcpyDeviceToDevice,
                      ctx.cuda_device_context().stream());
    } else {
      // One thread per output element, 1024 threads per block.
      int threadNum = batch_size * out_chw;
      int blocks = (threadNum + 1024 - 1) / 1024;

      KeBilinearInterpFw<
          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
          input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
          batch_size, out_chw, channels, ratio_h, ratio_w);
    }
  }
};

template <typename T>
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
    auto* d_output = d_output_t->data<T>();

    // Zero the input gradient before the kernel accumulates into it.
    auto& device_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> zero;
    zero(device_ctx, d_input_t, static_cast<T>(0.0));

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = d_input_t->dims()[0];
    int channels = d_input_t->dims()[1];
    int in_h = d_input_t->dims()[2];
    int in_w = d_input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      // Same size: the gradient passes through unchanged (device-to-device).
      cudaMemcpyAsync(d_input, d_output, d_input_t->numel() * sizeof(T),
                      cudaMemcpyDeviceToDevice,
                      ctx.cuda_device_context().stream());
    } else {
      int threadNum = batch_size * out_chw;
      int blocks = (threadNum + 1024 - 1) / 1024;

      KeBilinearInterpBw<
          T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
          d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
          batch_size, out_chw, channels, ratio_h, ratio_w);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
                        ops::BilinearInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
                        ops::BilinearInterpGradOpCUDAKernel<float>);
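
Both kernels launch one thread per output element: threadNum = batch_size * out_chw = N * C * out_h * out_w, split into ceil(threadNum / 1024) blocks of 1024 threads, and each thread unpacks its flat id into (batch, channel, y, x). A small host-side check of that decomposition, illustrative only and not part of the commit:

// Illustrative only -- not part of this commit. Verifies that the index
// arithmetic used by KeBilinearInterpFw/Bw round-trips correctly.
#include <cassert>

int main() {
  const int batch_size = 2, channels = 3, out_h = 4, out_w = 5;
  const int out_chw = channels * out_h * out_w;  // "output_w" in the kernels
  const int output_h = batch_size;               // "output_h" in the kernels
  const int threadNum = output_h * out_chw;
  const int blocks = (threadNum + 1024 - 1) / 1024;  // grid size
  assert(blocks == 1);  // 120 output elements fit in a single block here

  for (int tid = 0; tid < threadNum; ++tid) {
    int out_id_h = tid / out_chw;              // batch index n
    int out_id_w = tid % out_chw;              // offset inside one sample
    int out_img_size = out_chw / channels;     // out_h * out_w
    int channel_id = out_id_w / out_img_size;  // channel index c
    int out_img_idy = (out_id_w % out_img_size) / out_w;  // output row y
    int out_img_idx = tid % out_w;             // output column x
    // Re-packing (n, c, y, x) must reproduce the flat thread id.
    assert(tid == ((out_id_h * channels + channel_id) * out_h + out_img_idy) *
                          out_w +
                      out_img_idx);
  }
  return 0;
}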

0 commit comments