
Commit 67ce586
Author: wangyang59

gpu implementation of bilinear interp

1 parent f67f0ca

File tree
2 files changed: +96 additions, -1 deletion

paddle/fluid/operators/bilinear_interp_op.cu (new file)

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_cnn.h"
#include "paddle/fluid/operators/bilinear_interp_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;  // needed for the unqualified Tensor uses below

template <typename T>
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input_t = ctx.Input<Tensor>("X");      // float tensor
    auto* output_t = ctx.Output<Tensor>("Out");  // float tensor
    auto* input = input_t->data<T>();
    auto* output = output_t->mutable_data<T>(ctx.GetPlace());

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = input_t->dims()[0];
    int channels = input_t->dims()[1];
    int in_h = input_t->dims()[2];
    int in_w = input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    // Align-corners scaling ratios; 0 when the output has a single
    // row/column, so every output pixel maps to input position 0.
    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      // Same size: plain copy. Both pointers are device memory, so use a
      // device-to-device cudaMemcpy instead of host memcpy.
      cudaMemcpy(output, input, input_t->numel() * sizeof(T),
                 cudaMemcpyDeviceToDevice);
    } else {
      hl_bilinear_forward(input, in_h, in_w, batch_size, in_chw, output, out_h,
                          out_w, batch_size, out_chw, channels, ratio_h,
                          ratio_w);
    }
  }
};

template <typename T>
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
    auto* d_output = d_output_t->data<T>();

    int out_h = ctx.Attr<int>("out_h");
    int out_w = ctx.Attr<int>("out_w");
    int batch_size = d_input_t->dims()[0];
    int channels = d_input_t->dims()[1];
    int in_h = d_input_t->dims()[2];
    int in_w = d_input_t->dims()[3];

    int in_hw = in_h * in_w;
    int out_hw = out_h * out_w;
    int in_chw = channels * in_hw;
    int out_chw = channels * out_hw;

    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

    if (in_h == out_h && in_w == out_w) {
      // Same size: the gradient passes straight through (device-to-device).
      cudaMemcpy(d_input, d_output, d_input_t->numel() * sizeof(T),
                 cudaMemcpyDeviceToDevice);
    } else {
      hl_bilinear_backward(d_input, in_h, in_w, batch_size, in_chw, d_output,
                           out_h, out_w, batch_size, out_chw, channels,
                           ratio_h, ratio_w);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
                        ops::BilinearInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
                        ops::BilinearInterpGradOpCUDAKernel<float>);
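
Note: the body of hl_bilinear_forward is not part of this commit; it lives in the legacy hl_cnn GPU library. As a rough illustration of what such a kernel computes, here is a minimal, hypothetical CUDA sketch, mirroring the CPU reference loop in bilinear_interp_op.h. The kernel name, launch shape, and flat-index math are my assumptions, not the actual hl_cnn code:

// Hypothetical sketch -- NOT the hl_cnn implementation.
// One thread computes one output pixel of an NCHW tensor.
template <typename T>
__global__ void bilinear_forward_sketch(const T* in, T* out, int batch_size,
                                        int channels, int in_h, int in_w,
                                        int out_h, int out_w, T ratio_h,
                                        T ratio_w) {
  int total = batch_size * channels * out_h * out_w;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= total) return;

  // Unflatten idx into (n, c, i, j) for the NCHW layout.
  int j = idx % out_w;
  int i = (idx / out_w) % out_h;
  int c = (idx / (out_w * out_h)) % channels;
  int n = idx / (out_w * out_h * channels);

  // Top-left source pixel and interpolation weights, as in the CPU kernel.
  int h = static_cast<int>(ratio_h * i);
  int w = static_cast<int>(ratio_w * j);
  int hid = (h < in_h - 1) ? 1 : 0;  // 0 on the last row (edge clamp)
  int wid = (w < in_w - 1) ? 1 : 0;  // 0 on the last column (edge clamp)
  T h1lambda = ratio_h * i - h;
  T h2lambda = 1 - h1lambda;
  T w1lambda = ratio_w * j - w;
  T w2lambda = 1 - w1lambda;

  // Weighted sum of the four neighboring input pixels.
  const T* in_pos = &in[((n * channels + c) * in_h + h) * in_w + w];
  out[idx] = h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
             h1lambda * (w2lambda * in_pos[hid * in_w] +
                         w1lambda * in_pos[hid * in_w + wid]);
}

A launch such as bilinear_forward_sketch<<<(total + 255) / 256, 256>>>(...) assigns one thread per output element, the usual parallelization for this op; the backward pass is typically the mirror image, scattering each d_output element onto the same four input positions with atomicAdd.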

paddle/fluid/operators/bilinear_interp_op.h

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
     T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;

     if (in_h == out_h && in_w == out_w) {
-      memcpy(d_input, d_output, product(d_input_t->dims()) * sizeof(T));
+      memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
     } else {
       for (int k = 0; k < batch_size; ++k) {  // loop for batches
         for (int i = 0; i < out_h; ++i) {     // loop for images
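
For context on this one-line change: Tensor::numel() returns the total element count, i.e. the product of the tensor's dimension sizes, so the byte count passed to memcpy is unchanged; the edit presumably just moves off the free-standing product(dims) helper. A standalone illustration of the equivalence (plain C++, not Paddle code; this numel is a hypothetical stand-in):

#include <cstdint>
#include <vector>

// For an NCHW tensor of shape {N, C, H, W} the element count is
// N * C * H * W -- exactly what product(dims()) computed before.
int64_t numel(const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;
  return n;
}
// e.g. numel({2, 3, 4, 5}) == 120, so memcpy copies 120 * sizeof(T) bytes.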
