Skip to content

Commit ad3b3d9

Browse files
author
wangyang59
committed
ported old paddle gpu bilinear_interp
1 parent 67ce586 commit ad3b3d9

File tree

2 files changed

+121
-8
lines changed

2 files changed

+121
-8
lines changed

paddle/fluid/operators/bilinear_interp_op.cu

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

12-
#include "hl_cnn.h"
12+
#include "paddle/fluid/operators/bilinear_interp_op.cu.h"
1313
#include "paddle/fluid/operators/bilinear_interp_op.h"
1414

1515
namespace paddle {
@@ -44,9 +44,13 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
4444
if (in_h == out_h && in_w == out_w) {
4545
memcpy(output, input, input_t->numel() * sizeof(T));
4646
} else {
47-
hl_bilinear_forward(input, in_h, in_w, batch_size, in_chw, output, out_h,
48-
out_w, batch_size, out_chw, channels, ratio_h,
49-
ratio_w);
47+
int threadNum = batch_size * out_chw;
48+
int blocks = (threadNum + 1024 - 1) / 1024;
49+
50+
KeBilinearInterpFw<
51+
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
52+
input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
53+
batch_size, out_chw, channels, ratio_h, ratio_w);
5054
}
5155
}
5256
};
@@ -78,9 +82,13 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
7882
if (in_h == out_h && in_w == out_w) {
7983
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
8084
} else {
81-
hl_bilinear_backward(d_input, in_h, in_w, batch_size, in_chw, d_output,
82-
out_h, out_w, batch_size, out_chw, channels, ratio_h,
83-
ratio_w);
85+
int threadNum = batch_size * out_chw;
86+
int blocks = (threadNum + 1024 - 1) / 1024;
87+
88+
KeBilinearInterpBw<
89+
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
90+
d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
91+
batch_size, out_chw, channels, ratio_h, ratio_w);
8492
}
8593
}
8694
};
@@ -92,4 +100,4 @@ namespace ops = paddle::operators;
92100
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
93101
ops::BilinearInterpOpCUDAKernel<float>);
94102
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
95-
ops::BilinearInterpGradOpCUDAKernel<float>);
103+
ops::BilinearInterpGradOpCUDAKernel<float>);
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/fluid/framework/tensor.h"
17+
#include "paddle/fluid/platform/cuda_helper.h"
18+
#include "paddle/fluid/platform/place.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using Tensor = framework::Tensor;
24+
25+
// Bilinear-interpolation forward kernel (ported from legacy hl_cnn).
//
// Data layout (NCHW flattened to 2-D, as the caller in bilinear_interp_op.cu
// passes it): `in` is [inputH x inputW] with inputW = numChannels * inImgH *
// inImgW, and `out` is [outputH x outputW] with outputW = numChannels *
// outImgH * outImgW.  One thread produces one output element; launch with at
// least outputH * outputW threads — surplus threads exit via the bounds guard.
// ratioH/ratioW map an output pixel coordinate back to input coordinates.
template <typename T>
__global__ void KeBilinearInterpFw(const T* in, const size_t inImgH,
                                   const size_t inImgW, const size_t inputH,
                                   const size_t inputW, T* out,
                                   const size_t outImgH, const size_t outImgW,
                                   const size_t outputH, const size_t outputW,
                                   const size_t numChannels, const T ratioH,
                                   const T ratioW) {
  // Keep the flat index math in size_t: outputH * outputW can exceed INT_MAX
  // for large tensors, and `int` arithmetic here would silently overflow both
  // the bounds guard and the element offsets.
  const size_t nthreads = outputH * outputW;
  const size_t tid =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    const size_t outIdH = tid / outputW;  // batch index
    const size_t outIdW = tid % outputW;  // offset within one batch row
    const size_t inImgSize = inputW / numChannels;    // == inImgH * inImgW
    const size_t outImgSize = outputW / numChannels;  // == outImgH * outImgW
    const size_t channelId = outIdW / outImgSize;

    // Top source row and vertical interpolation weights (floor + fraction).
    const size_t outImgIdy = (outIdW % outImgSize) / outImgW;
    const size_t inImgIdy = static_cast<size_t>(ratioH * outImgIdy);
    // hId selects the row below, clamped to 0 at the bottom edge.
    const int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
    const T h1lambda = ratioH * outImgIdy - inImgIdy;
    const T h2lambda = 1.f - h1lambda;

    // Left source column and horizontal interpolation weights.
    const size_t outImgIdx = tid % outImgW;
    const size_t inImgIdx = static_cast<size_t>(ratioW * outImgIdx);
    // wId selects the column to the right, clamped to 0 at the right edge.
    const int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
    const T w1lambda = ratioW * outImgIdx - inImgIdx;
    const T w2lambda = 1.f - w1lambda;

    // Top-left corner of the 2x2 source neighborhood.
    const T* inPos = &in[outIdH * inputW + channelId * inImgSize +
                         inImgIdy * inImgW + inImgIdx];

    // Bilinear interpolation: weighted sum of the 2x2 neighborhood.
    out[outIdH * outputW + outIdW] =
        h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
        h1lambda * (w2lambda * inPos[hId * inImgW] +
                    w1lambda * inPos[hId * inImgW + wId]);
  }
}
64+
65+
// Bilinear-interpolation backward kernel (ported from legacy hl_cnn).
//
// Scatters each output gradient element back onto the 2x2 input neighborhood
// it was interpolated from, using the same weights as the forward pass.
// Layout mirrors KeBilinearInterpFw: here `out` holds the incoming gradient
// ([outputH x outputW]) and `in` accumulates the input gradient
// ([inputH x inputW]); `in` must be zero-initialized by the caller.
// One thread per output-gradient element; neighboring output pixels share
// input cells, so accumulation must be atomic.
// NOTE(review): atomicAdd on T requires T=float (or double on SM60+); the
// op currently registers only float — confirm before adding other types.
template <typename T>
__global__ void KeBilinearInterpBw(T* in, const size_t inImgH,
                                   const size_t inImgW, const size_t inputH,
                                   const size_t inputW, const T* out,
                                   const size_t outImgH, const size_t outImgW,
                                   const size_t outputH, const size_t outputW,
                                   const size_t numChannels, const T ratioH,
                                   const T ratioW) {
  // Keep the flat index math in size_t: outputH * outputW can exceed INT_MAX
  // for large tensors, and `int` arithmetic here would silently overflow both
  // the bounds guard and the element offsets.
  const size_t nthreads = outputH * outputW;
  const size_t tid =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    const size_t outIdH = tid / outputW;  // batch index
    const size_t outIdW = tid % outputW;  // offset within one batch row
    const size_t inImgSize = inputW / numChannels;    // == inImgH * inImgW
    const size_t outImgSize = outputW / numChannels;  // == outImgH * outImgW
    const size_t channelId = outIdW / outImgSize;

    // Top source row and vertical interpolation weights (floor + fraction).
    const size_t outImgIdy = (outIdW % outImgSize) / outImgW;
    const size_t inImgIdy = static_cast<size_t>(ratioH * outImgIdy);
    // hId selects the row below, clamped to 0 at the bottom edge.
    const int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
    const T h1lambda = ratioH * outImgIdy - inImgIdy;
    const T h2lambda = 1.f - h1lambda;

    // Left source column and horizontal interpolation weights.
    const size_t outImgIdx = tid % outImgW;
    const size_t inImgIdx = static_cast<size_t>(ratioW * outImgIdx);
    // wId selects the column to the right, clamped to 0 at the right edge.
    const int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
    const T w1lambda = ratioW * outImgIdx - inImgIdx;
    const T w2lambda = 1.f - w1lambda;

    // Top-left corner of the 2x2 destination neighborhood in the input grad.
    T* inPos = &in[outIdH * inputW + channelId * inImgSize +
                   inImgIdy * inImgW + inImgIdx];
    const T* outPos = &out[outIdH * outputW + outIdW];
    // Atomic scatter: several output pixels may map to the same input cell.
    atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]);
    atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]);
    atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]);
    atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]);
  }
}
103+
104+
} // namespace operators
105+
} // namespace paddle

0 commit comments

Comments
 (0)