Skip to content

Commit 5d33481

Browse files
committed
Add bilinear interp supporting for uint8
1 parent a29cb4b commit 5d33481

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

paddle/fluid/operators/bilinear_interp_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
110110
ops::BilinearInterpOpMaker,
111111
paddle::framework::DefaultGradOpDescMaker<true>);
112112
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
113-
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>);
113+
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
114+
ops::BilinearInterpKernel<uint8_t>);
114115
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
115-
ops::BilinearInterpGradKernel<float>);
116+
ops::BilinearInterpGradKernel<float>,
117+
ops::BilinearInterpGradKernel<uint8_t>);

paddle/fluid/operators/bilinear_interp_op.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
4646
int in_chw = channels * in_hw;
4747
int out_chw = channels * out_hw;
4848

49-
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
50-
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
49+
float ratio_h =
50+
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
51+
float ratio_w =
52+
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
5153

5254
if (in_h == out_h && in_w == out_w) {
5355
memcpy(output, input, input_t->numel() * sizeof(T));
@@ -56,14 +58,14 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
5658
for (int i = 0; i < out_h; ++i) { // loop for images
5759
int h = ratio_h * i;
5860
int hid = (h < in_h - 1) ? 1 : 0;
59-
T h1lambda = ratio_h * i - h;
60-
T h2lambda = 1 - h1lambda;
61+
float h1lambda = ratio_h * i - h;
62+
float h2lambda = 1.f - h1lambda;
6163

6264
for (int j = 0; j < out_w; ++j) {
6365
int w = ratio_w * j;
6466
int wid = (w < in_w - 1) ? 1 : 0;
65-
T w1lambda = ratio_w * j - w;
66-
T w2lambda = 1 - w1lambda;
67+
float w1lambda = ratio_w * j - w;
68+
float w2lambda = 1.f - w1lambda;
6769
// calculate four position for bilinear interpolation
6870
const T* in_pos = &input[k * in_chw + h * in_w + w];
6971
T* out_pos = &output[k * out_chw + i * out_w + j];
@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
117119
int in_chw = channels * in_hw;
118120
int out_chw = channels * out_hw;
119121

120-
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
121-
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
122+
float ratio_h =
123+
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
124+
float ratio_w =
125+
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
122126

123127
if (in_h == out_h && in_w == out_w) {
124128
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
@@ -127,14 +131,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
127131
for (int i = 0; i < out_h; ++i) { // loop for images
128132
int h = ratio_h * i;
129133
int hid = (h < in_h - 1) ? 1 : 0;
130-
T h1lambda = ratio_h * i - h;
131-
T h2lambda = 1 - h1lambda;
134+
float h1lambda = ratio_h * i - h;
135+
float h2lambda = 1 - h1lambda;
132136

133137
for (int j = 0; j < out_w; ++j) {
134138
int w = ratio_w * j;
135139
int wid = (w < in_w - 1) ? 1 : 0;
136-
T w1lambda = ratio_w * j - w;
137-
T w2lambda = 1 - w1lambda;
140+
float w1lambda = ratio_w * j - w;
141+
float w2lambda = 1 - w1lambda;
138142
T* in_pos = &d_input[k * in_chw + h * in_w + w];
139143
const T* out_pos = &d_output[k * out_chw + i * out_w + j];
140144

0 commit comments

Comments
 (0)