Skip to content

Commit e402c0e

Browse files
committed
test=develop
1 parent 334f697 commit e402c0e

File tree

7 files changed

+551
-107
lines changed

7 files changed

+551
-107
lines changed

paddle/fluid/API.spec

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,10 @@ paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon',
142142
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
143143
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
144144
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
145-
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None))
145+
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1))
146146
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
147-
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None))
148-
paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None))
147+
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1))
148+
paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True))
149149
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
150150
paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
151151
paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))

paddle/fluid/operators/interpolate_op.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
8282
"bilinear interpolation and \"nearest\" for nearest "
8383
"neighbor interpolation.")
8484
.SetDefault("bilinear");
85+
AddAttr<bool>(
86+
"align_corners",
87+
"an optional bool. Defaults to True. "
88+
"If True, the centers of 4 corner pixels of the input and output "
89+
"tensors are aligned, preserving the values at the corner pixels, "
90+
"if False, are not aligned")
91+
.SetDefault(true);
92+
AddAttr<int>("align_mode",
93+
"(int, default \'1\'), optional for bilinear interpolation, "
94+
"can be \'0\' for src_idx = scale*(dst_index+0.5)-0.5 , "
95+
"can be \'1\' for src_idx = scale*dst_index .")
96+
.SetDefault(1);
8597
AddComment(R"DOC(
8698
This operator samples input X to given output shape by using specified
8799
interpolation method, the interpolation methods can be \"nearest\"
@@ -98,6 +110,64 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
98110
to perform linear interpolation first in one direction, and then
99111
again in the other direction.
100112
113+
Align_corners and align_mode are optional parameters; the calculation method
114+
of interpolation can be selected by them.
115+
116+
Example:
117+
118+
For scale:
119+
120+
if align_corners = True and out_{size}>1 :
121+
122+
scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
123+
124+
else:
125+
126+
scale_{factor} = float(in_{size}/out_{size})
127+
128+
129+
Nearest neighbor interpolation:
130+
131+
if:
132+
align_corners = False
133+
134+
input : (N,C,H_in,W_in)
135+
output: (N,C,H_out,W_out) where:
136+
137+
H_out = \left \lfloor {H_{in} * scale_{factor}} \right \rfloor
138+
W_out = \left \lfloor {W_{in} * scale_{factor}} \right \rfloor
139+
140+
else:
141+
align_corners = True
142+
143+
input : (N,C,H_in,W_in)
144+
output: (N,C,H_out,W_out) where:
145+
146+
H_out = round(H_{in} * scale_{factor})
147+
W_out = round(W_{in} * scale_{factor})
148+
149+
Bilinear interpolation:
150+
151+
if:
152+
align_corners = False , align_mode = 0
153+
154+
input : (N,C,H_in,W_in)
155+
output: (N,C,H_out,W_out) where:
156+
157+
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
158+
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
159+
160+
161+
else:
162+
163+
input : (N,C,H_in,W_in)
164+
output: (N,C,H_out,W_out) where:
165+
166+
H_out = H_{in} * scale_{factor}
167+
W_out = W_{in} * scale_{factor}
168+
169+
170+
101171
For details of nearest neighbor interpolation, please refer to Wikipedia:
102172
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
103173

paddle/fluid/operators/interpolate_op.cu

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ __global__ void KeNearestNeighborInterpFw(
2323
const T* in, const size_t in_img_h, const size_t in_img_w,
2424
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
2525
const size_t out_img_w, const size_t output_h, const size_t output_w,
26-
const size_t num_channels, const float ratio_h, const float ratio_w) {
26+
const size_t num_channels, const float ratio_h, const float ratio_w,
27+
const bool align_corners) {
2728
int nthreads = output_h * output_w;
2829
int tid = blockIdx.x * blockDim.x + threadIdx.x;
2930
int stride = blockDim.x * gridDim.x;
@@ -35,10 +36,14 @@ __global__ void KeNearestNeighborInterpFw(
3536
int channel_id = out_id_w / out_img_size;
3637

3738
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
38-
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
39+
int in_img_idy = (align_corners)
40+
? static_cast<int>(ratio_h * out_img_idy + 0.5)
41+
: static_cast<int>(ratio_h * out_img_idy);
3942

4043
int out_img_idx = tid % out_img_w;
41-
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
44+
int in_img_idx = (align_corners)
45+
? static_cast<int>(ratio_w * out_img_idx + 0.5)
46+
: static_cast<int>(ratio_w * out_img_idx);
4247

4348
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
4449
in_img_idy * in_img_w + in_img_idx];
@@ -50,7 +55,8 @@ __global__ void KeNearestNeighborInterpBw(
5055
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
5156
const size_t input_w, const T* out, const size_t out_img_h,
5257
const size_t out_img_w, const size_t output_h, const size_t output_w,
53-
const size_t num_channels, const float ratio_h, const float ratio_w) {
58+
const size_t num_channels, const float ratio_h, const float ratio_w,
59+
const bool align_corners) {
5460
int nthreads = output_h * output_w;
5561
int tid = blockIdx.x * blockDim.x + threadIdx.x;
5662
int stride = blockDim.x * gridDim.x;
@@ -62,10 +68,14 @@ __global__ void KeNearestNeighborInterpBw(
6268
int channel_id = out_id_w / out_img_size;
6369

6470
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
65-
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
71+
int in_img_idy = (align_corners)
72+
? static_cast<int>(ratio_h * out_img_idy + 0.5)
73+
: static_cast<int>(ratio_h * out_img_idy);
6674

6775
int out_img_idx = tid % out_img_w;
68-
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
76+
int in_img_idx = (align_corners)
77+
? static_cast<int>(ratio_w * out_img_idx + 0.5)
78+
: static_cast<int>(ratio_w * out_img_idx);
6979

7080
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
7181
in_img_idy * in_img_w + in_img_idx];
@@ -79,10 +89,12 @@ __global__ void KeBilinearInterpFw(
7989
const T* in, const size_t in_img_h, const size_t in_img_w,
8090
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
8191
const size_t out_img_w, const size_t output_h, const size_t output_w,
82-
const size_t num_channels, const float ratio_h, const float ratio_w) {
92+
const size_t num_channels, const float ratio_h, const float ratio_w,
93+
const bool align_corners, const int align_mode) {
8394
int nthreads = output_h * output_w;
8495
int tid = blockIdx.x * blockDim.x + threadIdx.x;
8596
int stride = blockDim.x * gridDim.x;
97+
bool align_flag = (align_mode == 0 && !align_corners);
8698
for (; tid < nthreads; tid += stride) {
8799
int out_id_h = tid / output_w;
88100
int out_id_w = tid % output_w;
@@ -91,15 +103,23 @@ __global__ void KeBilinearInterpFw(
91103
int channel_id = out_id_w / out_img_size;
92104

93105
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
94-
int in_img_idy = ratio_h * out_img_idy;
106+
int in_img_idy = align_flag
107+
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
108+
: static_cast<int>(ratio_h * out_img_idy);
109+
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
95110
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
96-
T h1lambda = ratio_h * out_img_idy - in_img_idy;
111+
T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy
112+
: ratio_h * out_img_idy - in_img_idy;
97113
T h2lambda = 1.f - h1lambda;
98114

99115
int out_img_idx = tid % out_img_w;
100-
int in_img_idx = ratio_w * out_img_idx;
116+
int in_img_idx = align_flag
117+
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
118+
: static_cast<int>(ratio_w * out_img_idx);
119+
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
101120
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
102-
T w1lambda = ratio_w * out_img_idx - in_img_idx;
121+
T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx
122+
: ratio_w * out_img_idx - in_img_idx;
103123
T w2lambda = 1.f - w1lambda;
104124

105125
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
@@ -118,10 +138,12 @@ __global__ void KeBilinearInterpBw(
118138
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
119139
const size_t input_w, const T* out, const size_t out_img_h,
120140
const size_t out_img_w, const size_t output_h, const size_t output_w,
121-
const size_t num_channels, const T ratio_h, const T ratio_w) {
141+
const size_t num_channels, const T ratio_h, const T ratio_w,
142+
const bool align_corners, const int align_mode) {
122143
int nthreads = output_h * output_w;
123144
int tid = blockIdx.x * blockDim.x + threadIdx.x;
124145
int stride = blockDim.x * gridDim.x;
146+
bool align_flag = (align_mode == 0 && !align_corners);
125147
for (; tid < nthreads; tid += stride) {
126148
int out_id_h = tid / output_w;
127149
int out_id_w = tid % output_w;
@@ -130,15 +152,22 @@ __global__ void KeBilinearInterpBw(
130152
int channel_id = out_id_w / out_img_size;
131153

132154
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
133-
int in_img_idy = ratio_h * out_img_idy;
155+
int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5
156+
: ratio_h * out_img_idy;
157+
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
134158
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
135-
T h1lambda = ratio_h * out_img_idy - in_img_idy;
159+
T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy
160+
: ratio_h * out_img_idy - in_img_idy;
161+
136162
T h2lambda = 1.f - h1lambda;
137163

138164
int out_img_idx = tid % out_img_w;
139-
int in_img_idx = ratio_w * out_img_idx;
165+
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
166+
: ratio_w * out_img_idx;
167+
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
140168
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
141-
T w1lambda = ratio_w * out_img_idx - in_img_idx;
169+
T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx
170+
: ratio_w * out_img_idx - in_img_idx;
142171
T w2lambda = 1.f - w1lambda;
143172

144173
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
@@ -175,6 +204,9 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
175204
out_w = size_data[1];
176205
}
177206

207+
bool align_corners = ctx.Attr<bool>("align_corners");
208+
int align_mode = ctx.Attr<int>("align_mode");
209+
178210
int n = input->dims()[0];
179211
int c = input->dims()[1];
180212
int in_h = input->dims()[2];
@@ -188,10 +220,16 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
188220
int in_chw = c * in_hw;
189221
int out_chw = c * out_hw;
190222

191-
float ratio_h =
192-
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
193-
float ratio_w =
194-
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
223+
float ratio_h = 0.f;
224+
float ratio_w = 0.f;
225+
if (out_h > 1) {
226+
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
227+
: static_cast<float>(in_h) / out_h;
228+
}
229+
if (out_w > 1) {
230+
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
231+
: static_cast<float>(in_w) / out_w;
232+
}
195233

196234
if (in_h == out_h && in_w == out_w) {
197235
framework::TensorCopy(*input, ctx.GetPlace(), output);
@@ -206,12 +244,12 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
206244
KeNearestNeighborInterpFw<
207245
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
208246
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
209-
out_chw, c, ratio_h, ratio_w);
247+
out_chw, c, ratio_h, ratio_w, align_corners);
210248
} else if ("bilinear" == interp_method) {
211249
KeBilinearInterpFw<
212250
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
213251
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
214-
out_chw, c, ratio_h, ratio_w);
252+
out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
215253
}
216254
}
217255
};
@@ -234,6 +272,10 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
234272
int out_h = ctx.Attr<int>("out_h");
235273
int out_w = ctx.Attr<int>("out_w");
236274
auto out_size = ctx.Input<Tensor>("OutSize");
275+
276+
bool align_corners = ctx.Attr<bool>("align_corners");
277+
int align_mode = ctx.Attr<int>("align_mode");
278+
237279
if (out_size != nullptr) {
238280
Tensor sizes;
239281
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
@@ -252,10 +294,16 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
252294
int in_chw = c * in_hw;
253295
int out_chw = c * out_hw;
254296

255-
float ratio_h =
256-
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
257-
float ratio_w =
258-
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
297+
float ratio_h = 0.f;
298+
float ratio_w = 0.f;
299+
if (out_h > 1) {
300+
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
301+
: static_cast<float>(in_h) / out_h;
302+
}
303+
if (out_w > 1) {
304+
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
305+
: static_cast<float>(in_w) / out_w;
306+
}
259307

260308
if (in_h == out_h && in_w == out_w) {
261309
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
@@ -270,12 +318,12 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
270318
KeNearestNeighborInterpBw<
271319
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
272320
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
273-
out_w, n, out_chw, c, ratio_h, ratio_w);
321+
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners);
274322
} else if ("bilinear" == interp_method) {
275323
KeBilinearInterpBw<
276324
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
277325
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
278-
out_w, n, out_chw, c, ratio_h, ratio_w);
326+
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
279327
}
280328
}
281329
};

0 commit comments

Comments
 (0)