Commit 6e1c48d

Merge pull request #11576 from JiayiFeng/dev_refine_bilinear_interp
Add bilinear interp support for uint8
2 parents: 80f6364 + e418862

File tree: 5 files changed, 71 additions and 23 deletions

paddle/fluid/operators/bilinear_interp_op.cc

Lines changed: 2 additions & 1 deletion
@@ -110,6 +110,7 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
                   ops::BilinearInterpOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
-REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>);
+REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
+                       ops::BilinearInterpKernel<uint8_t>);
 REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
                        ops::BilinearInterpGradKernel<float>);
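The registration change touches only the forward CPU kernel: bilinear_interp now dispatches for both float and uint8_t inputs, while bilinear_interp_grad stays float-only, since gradients are not propagated through integer image data. A hedged usage sketch from the Python side (it assumes the fluid wrapper fluid.layers.resize_bilinear, which emits a bilinear_interp op; the wrapper itself is not part of this commit):

# Hypothetical sketch: exercising the new uint8 forward kernel from Python.
# `fluid.layers.resize_bilinear` is an assumption here, not part of this diff.
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[3, 9, 6], dtype='uint8')
out = fluid.layers.resize_bilinear(x, out_shape=[10, 9])

exe = fluid.Executor(fluid.CPUPlace())  # only a CPU kernel exists for uint8
img = np.random.randint(0, 256, size=(1, 3, 9, 6)).astype('uint8')
result, = exe.run(fluid.default_main_program(),
                  feed={'x': img},
                  fetch_list=[out])
print(result.dtype, result.shape)  # uint8 (1, 10, 9)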

paddle/fluid/operators/bilinear_interp_op.h

Lines changed: 24 additions & 18 deletions
@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
     int in_chw = channels * in_hw;
     int out_chw = channels * out_hw;
 
-    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
-    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
+    float ratio_h =
+        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
+    float ratio_w =
+        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
 
     if (in_h == out_h && in_w == out_w) {
       memcpy(output, input, input_t->numel() * sizeof(T));
@@ -56,24 +58,24 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
       for (int i = 0; i < out_h; ++i) {  // loop for images
         int h = ratio_h * i;
         int hid = (h < in_h - 1) ? 1 : 0;
-        T h1lambda = ratio_h * i - h;
-        T h2lambda = 1 - h1lambda;
+        float h1lambda = ratio_h * i - h;
+        float h2lambda = 1.f - h1lambda;
 
         for (int j = 0; j < out_w; ++j) {
           int w = ratio_w * j;
           int wid = (w < in_w - 1) ? 1 : 0;
-          T w1lambda = ratio_w * j - w;
-          T w2lambda = 1 - w1lambda;
+          float w1lambda = ratio_w * j - w;
+          float w2lambda = 1.f - w1lambda;
           // calculate four position for bilinear interpolation
           const T* in_pos = &input[k * in_chw + h * in_w + w];
           T* out_pos = &output[k * out_chw + i * out_w + j];
 
           for (int c = 0; c < channels; ++c) {  // loop for channels
             // bilinear interpolation
-            out_pos[0] =
+            out_pos[0] = static_cast<T>(
                 h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
                 h1lambda * (w2lambda * in_pos[hid * in_w] +
-                            w1lambda * in_pos[hid * in_w + wid]);
+                            w1lambda * in_pos[hid * in_w + wid]));
             in_pos += in_hw;
             out_pos += out_hw;
           }
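The switch from T to float for the ratios and lambda weights is what makes the uint8_t instantiation meaningful: with T = uint8_t, static_cast<T>(in_h - 1) / (out_h - 1) is integer arithmetic, so the ratio and every fractional weight truncate to zero and the kernel would degenerate into copying the top-left neighbour. Only the final store, out_pos[0] = static_cast<T>(...), narrows back to the element type. A minimal numpy illustration of the difference:

import numpy as np

in_h, out_h = 9, 10
i = 3

# If the weights were computed in T = uint8_t, integer arithmetic would
# truncate the ratio, and with it every fractional weight, to zero:
ratio_u8 = np.uint8(in_h - 1) // np.uint8(out_h - 1)  # 8 // 9 == 0

# With float intermediates (the change above), the weights survive:
ratio_f = float(in_h - 1) / (out_h - 1)  # 0.888...
h = int(ratio_f * i)                     # 2, top source row
h1lambda = ratio_f * i - h               # 0.666..., lost in uint8 math
h2lambda = 1.0 - h1lambda
print(ratio_u8, ratio_f, h, h1lambda, h2lambda)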
@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
     int in_chw = channels * in_hw;
     int out_chw = channels * out_hw;
 
-    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
-    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
+    float ratio_h =
+        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
+    float ratio_w =
+        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
 
     if (in_h == out_h && in_w == out_w) {
       memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
@@ -127,22 +131,24 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
       for (int i = 0; i < out_h; ++i) {  // loop for images
         int h = ratio_h * i;
         int hid = (h < in_h - 1) ? 1 : 0;
-        T h1lambda = ratio_h * i - h;
-        T h2lambda = 1 - h1lambda;
+        float h1lambda = ratio_h * i - h;
+        float h2lambda = 1 - h1lambda;
 
         for (int j = 0; j < out_w; ++j) {
           int w = ratio_w * j;
           int wid = (w < in_w - 1) ? 1 : 0;
-          T w1lambda = ratio_w * j - w;
-          T w2lambda = 1 - w1lambda;
+          float w1lambda = ratio_w * j - w;
+          float w2lambda = 1 - w1lambda;
           T* in_pos = &d_input[k * in_chw + h * in_w + w];
           const T* out_pos = &d_output[k * out_chw + i * out_w + j];
 
           for (int c = 0; c < channels; ++c) {  // loop for channels
-            in_pos[0] += h2lambda * w2lambda * out_pos[0];
-            in_pos[wid] += h2lambda * w1lambda * out_pos[0];
-            in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0];
-            in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0];
+            in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
+            in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
+            in_pos[hid * in_w] +=
+                static_cast<T>(h1lambda * w2lambda * out_pos[0]);
+            in_pos[hid * in_w + wid] +=
+                static_cast<T>(h1lambda * w1lambda * out_pos[0]);
             in_pos += in_hw;
             out_pos += out_hw;
           }
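The gradient kernel mirrors the forward change: weights are computed in float, and each incoming gradient value is split across the four neighbouring input cells, with static_cast<T> making the narrowing on the += explicit (in practice T is still float here, as no uint8 grad kernel is registered). A small numpy sketch of that scatter for a single output pixel, using the diff's variable names:

import numpy as np

# Minimal sketch: scatter one output-gradient value across the four
# neighbouring input positions, mirroring the kernel's weights.
d_input = np.zeros((4, 4), dtype=np.float32)
h, w = 1, 2                     # top-left source cell
hid, wid = 1, 1                 # step flags, 0 at the bottom/right border
h1lambda, w1lambda = 0.25, 0.6  # fractional offsets
h2lambda, w2lambda = 1 - h1lambda, 1 - w1lambda
d = 1.0                         # incoming gradient, out_pos[0] in the kernel

d_input[h, w] += h2lambda * w2lambda * d
d_input[h, w + wid] += h2lambda * w1lambda * d
d_input[h + hid, w] += h1lambda * w2lambda * d
d_input[h + hid, w + wid] += h1lambda * w1lambda * d

assert np.isclose(d_input.sum(), d)  # the four weights sum to 1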

paddle/fluid/operators/math/math_function.cc

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ template struct SetConstant<platform::CPUDeviceContext, double>;
 template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
+template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
 
 #define DEFINE_CPU_TRANS(RANK) \
   template struct Transpose<platform::CPUDeviceContext, platform::float16, \

paddle/fluid/pybind/tensor_py.h

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
 inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) {
   auto buffer_info =
       details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
-                                  platform::float16>()(tensor);
+                                  uint8_t, platform::float16>()(tensor);
   return buffer_info;
 }
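Adding uint8_t to the CastToPyBufferImpl type list lets pybind11 expose uint8 tensors through the buffer protocol; without it, fetching the new kernel's output into numpy would fail on the unsupported dtype. A sketch of the round trip, under the assumption that LoDTensor.set and numpy's buffer-based conversion behave for uint8 as they do for the other listed dtypes:

import numpy as np
import paddle.fluid.core as core

t = core.LoDTensor()
t.set(np.random.randint(0, 256, size=(2, 3)).astype('uint8'),
      core.CPUPlace())
arr = np.array(t)            # conversion goes through CastToPyBuffer
print(arr.dtype, arr.shape)  # uint8 (2, 3)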

python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py

Lines changed: 43 additions & 3 deletions
@@ -15,6 +15,7 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
 
 
 def bilinear_interp_np(input, out_h, out_w, out_size):
@@ -45,9 +46,9 @@ def bilinear_interp_np(input, out_h, out_w, out_size):
 
             out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
                                         w1lambda*input[:, :, h, w+wid]) + \
-                             h1lambda*(w2lambda*input[:, :, h+hid, w] +
-                                       w1lambda*input[:, :, h+hid, w+wid])
-    return out.astype("float32")
+                              h1lambda*(w2lambda*input[:, :, h+hid, w] +
+                                        w1lambda*input[:, :, h+hid, w+wid])
+    return out.astype(input.dtype)
 
 
 class TestBilinearInterpOp(OpTest):
@@ -122,5 +123,44 @@ def init_test_case(self):
         self.out_size = np.array([65, 129]).astype("int32")
 
 
+class TestBilinearInterpOpUint8(OpTest):
+    def setUp(self):
+        self.out_size = None
+        self.init_test_case()
+        self.op_type = "bilinear_interp"
+        input_np = np.random.randint(
+            low=0, high=256, size=self.input_shape).astype("uint8")
+        output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
+                                       self.out_size)
+        self.inputs = {'X': input_np}
+        if self.out_size is not None:
+            self.inputs['OutSize'] = self.out_size
+        self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output_with_place(place=core.CPUPlace(), atol=1)
+
+    def init_test_case(self):
+        self.input_shape = [1, 3, 9, 6]
+        self.out_h = 10
+        self.out_w = 9
+
+
+class TestCase1Uint8(TestBilinearInterpOpUint8):
+    def init_test_case(self):
+        self.input_shape = [2, 3, 128, 64]
+        self.out_h = 120
+        self.out_w = 50
+
+
+class TestCase2Uint8(TestBilinearInterpOpUint8):
+    def init_test_case(self):
+        self.input_shape = [4, 1, 7, 8]
+        self.out_h = 5
+        self.out_w = 13
+        self.out_size = np.array([6, 15]).astype("int32")
+
+
 if __name__ == "__main__":
     unittest.main()
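The uint8 tests compare against the numpy reference with atol=1 rather than exactly: both the kernel and bilinear_interp_np truncate the interpolated value back to uint8, but the kernel accumulates in float32 and the reference in float64, so a result landing near an integer boundary can truncate to neighbours one apart. A worked 2x2 -> 3x3 example of the reference computation, in plain numpy and mirroring the kernel's indexing:

import numpy as np

x = np.array([[10, 20],
              [30, 40]], dtype=np.uint8)  # one 2x2 channel
in_hw, out_hw = 2, 3
ratio = (in_hw - 1.0) / (out_hw - 1.0)    # 0.5 on both axes

out = np.zeros((out_hw, out_hw))
for i in range(out_hw):
    h = int(ratio * i)                    # top source row
    hid = 1 if h < in_hw - 1 else 0       # clamp at the bottom border
    h1, h2 = ratio * i - h, 1 - (ratio * i - h)
    for j in range(out_hw):
        w = int(ratio * j)                # left source column
        wid = 1 if w < in_hw - 1 else 0   # clamp at the right border
        w1, w2 = ratio * j - w, 1 - (ratio * j - w)
        out[i, j] = h2 * (w2 * x[h, w] + w1 * x[h, w + wid]) + \
                    h1 * (w2 * x[h + hid, w] + w1 * x[h + hid, w + wid])

print(out.astype(np.uint8))
# [[10 15 20]
#  [20 25 30]
#  [30 35 40]]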
