Skip to content

Commit df4a354

Browse files
committed
nearest neighbor interp add cuda kernel. test=develop
1 parent 9755611 commit df4a354

File tree

5 files changed

+111
-93
lines changed

5 files changed

+111
-93
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], vararg
121121
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
122122
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
123123
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
124+
paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
124125
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
125126
paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
126127
paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))

paddle/fluid/operators/nearest_neighbor_interp_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel {
2525
protected:
2626
void InferShape(framework::InferShapeContext* ctx) const override {
2727
PADDLE_ENFORCE(ctx->HasInput("X"),
28-
"Input(X) of BilinearInterOp should not be null.");
28+
"Input(X) of NearestNeighborInterOp should not be null.");
2929
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30-
"Output(Out) of BilinearInterOp should not be null.");
30+
"Output(Out) of NearestNeighborInterOp should not be null.");
3131

3232
auto dim_x = ctx->GetInputDim("X"); // NCHW format
3333
int out_h = ctx->Attrs().Get<int>("out_h");
@@ -64,8 +64,9 @@ class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker {
6464
.AsDispensable();
6565
AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)");
6666

67-
AddAttr<int>("out_h", "output height of bilinear interpolation op.");
68-
AddAttr<int>("out_w", "output width of bilinear interpolation op.");
67+
AddAttr<int>("out_h",
68+
"output height of nearest neighbor interpolation op.");
69+
AddAttr<int>("out_w", "output width of nearest neighbor interpolation op.");
6970
AddComment(R"DOC(
7071
Nearest neighbor interpolation is to perform nearest neighbor interpolation
7172
in bot the 3rd dimention(in height direction) and the 4th dimention(in width

paddle/fluid/operators/nearest_neighbor_interp_op.cu

Lines changed: 63 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,14 @@
1515
namespace paddle {
1616
namespace operators {
1717

18-
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
19-
typename IndexType = Eigen::DenseIndex>
20-
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
2118
using framework::Tensor;
2219

2320
template <typename T>
24-
__global__ void KeBilinearInterpFw(
21+
__global__ void KeNearestNeighborInterpFw(
2522
const T* in, const size_t in_img_h, const size_t in_img_w,
2623
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
2724
const size_t out_img_w, const size_t output_h, const size_t output_w,
28-
const size_t num_channels, const T ratio_h, const T ratioW) {
25+
const size_t num_channels, const T ratio_h, const T ratio_w) {
2926
int nthreads = output_h * output_w;
3027
int tid = blockIdx.x * blockDim.x + threadIdx.x;
3128
if (tid < nthreads) {
@@ -36,34 +33,22 @@ __global__ void KeBilinearInterpFw(
3633
int channel_id = out_id_w / out_img_size;
3734

3835
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
39-
int in_img_idy = ratio_h * out_img_idy;
40-
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
41-
T h1lambda = ratio_h * out_img_idy - in_img_idy;
42-
T h2lambda = 1.f - h1lambda;
36+
int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));
4337

4438
int out_img_idx = tid % out_img_w;
45-
int in_img_idx = ratioW * out_img_idx;
46-
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
47-
T w1lambda = ratioW * out_img_idx - in_img_idx;
48-
T w2lambda = 1.f - w1lambda;
49-
50-
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
51-
in_img_idy * in_img_w + in_img_idx];
52-
53-
// bilinear interpolation
54-
out[out_id_h * output_w + out_id_w] =
55-
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
56-
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
57-
w1lambda * in_pos[h_id * in_img_w + w_id]);
39+
int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));
40+
41+
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
42+
in_img_idy * in_img_w + in_img_idx];
5843
}
5944
}
6045

6146
template <typename T>
62-
__global__ void KeBilinearInterpBw(
47+
__global__ void KeNearestNeighborInterpBw(
6348
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
6449
const size_t input_w, const T* out, const size_t out_img_h,
6550
const size_t out_img_w, const size_t output_h, const size_t output_w,
66-
const size_t num_channels, const T ratio_h, const T ratioW) {
51+
const size_t num_channels, const T ratio_h, const T ratio_w) {
6752
int nthreads = output_h * output_w;
6853
int tid = blockIdx.x * blockDim.x + threadIdx.x;
6954
if (tid < nthreads) {
@@ -74,25 +59,15 @@ __global__ void KeBilinearInterpBw(
7459
int channel_id = out_id_w / out_img_size;
7560

7661
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
77-
int in_img_idy = ratio_h * out_img_idy;
78-
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
79-
T h1lambda = ratio_h * out_img_idy - in_img_idy;
80-
T h2lambda = 1.f - h1lambda;
62+
int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));
8163

8264
int out_img_idx = tid % out_img_w;
83-
int in_img_idx = ratioW * out_img_idx;
84-
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
85-
T w1lambda = ratioW * out_img_idx - in_img_idx;
86-
T w2lambda = 1.f - w1lambda;
65+
int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));
8766

8867
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
8968
in_img_idy * in_img_w + in_img_idx];
90-
const T* out_pos = &out[out_id_h * output_w + out_id_w];
91-
atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
92-
atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
93-
atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
94-
atomicAdd(&in_pos[h_id * in_img_w + w_id],
95-
h1lambda * w1lambda * out_pos[0]);
69+
const T out_pos = out[out_id_h * output_w + out_id_w];
70+
atomicAdd(in_pos, out_pos);
9671
}
9772
}
9873

@@ -102,101 +77,103 @@ class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> {
10277
void Compute(const framework::ExecutionContext& ctx) const override {
10378
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
10479
"This kernel only runs on GPU device.");
105-
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
106-
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
107-
auto* input = input_t->data<T>();
80+
auto* input = ctx.Input<Tensor>("X"); // float tensor
81+
auto* output = ctx.Output<Tensor>("Out"); // float tensor
82+
auto* input_data = input->data<T>();
10883

10984
int out_h = ctx.Attr<int>("out_h");
11085
int out_w = ctx.Attr<int>("out_w");
111-
auto out_dims = output_t->dims();
112-
auto out_size_t = ctx.Input<Tensor>("OutSize");
113-
if (out_size_t != nullptr) {
86+
auto out_size = ctx.Input<Tensor>("OutSize");
87+
if (out_size != nullptr) {
11488
Tensor sizes;
115-
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
89+
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
11690
auto size_data = sizes.data<int>();
11791
out_h = size_data[0];
11892
out_w = size_data[1];
11993
}
120-
auto* output = output_t->mutable_data<T>(
121-
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
12294

123-
int batch_size = input_t->dims()[0];
124-
int channels = input_t->dims()[1];
125-
int in_h = input_t->dims()[2];
126-
int in_w = input_t->dims()[3];
95+
int n = input->dims()[0];
96+
int c = input->dims()[1];
97+
int in_h = input->dims()[2];
98+
int in_w = input->dims()[3];
99+
100+
auto* output_data =
101+
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
127102

128103
int in_hw = in_h * in_w;
129104
int out_hw = out_h * out_w;
130-
int in_chw = channels * in_hw;
131-
int out_chw = channels * out_hw;
105+
int in_chw = c * in_hw;
106+
int out_chw = c * out_hw;
132107

133108
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
134109
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
135110

136111
if (in_h == out_h && in_w == out_w) {
137-
memcpy(output, input, input_t->numel() * sizeof(T));
138-
} else {
139-
int threadNum = batch_size * out_chw;
140-
int blocks = (threadNum + 1024 - 1) / 1024;
141-
142-
KeBilinearInterpFw<
143-
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
144-
input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
145-
batch_size, out_chw, channels, ratio_h, ratio_w);
112+
memcpy(output_data, input_data, input->numel() * sizeof(T));
113+
return;
146114
}
115+
116+
int threadNum = n * out_chw;
117+
int blocks = (threadNum + 1024 - 1) / 1024;
118+
119+
KeNearestNeighborInterpFw<
120+
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
121+
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
122+
out_chw, c, ratio_h, ratio_w);
147123
}
148124
};
149125

150126
template <typename T>
151127
class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
152128
public:
153129
void Compute(const framework::ExecutionContext& ctx) const override {
154-
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
155-
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
156-
auto* d_output = d_output_t->data<T>();
157-
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
130+
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
131+
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
132+
auto* output_grad_data = output_grad->data<T>();
133+
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
158134

159135
auto& device_ctx =
160136
ctx.template device_context<platform::CUDADeviceContext>();
161137
math::SetConstant<platform::CUDADeviceContext, T> zero;
162-
zero(device_ctx, d_input_t, static_cast<T>(0.0));
138+
zero(device_ctx, input_grad, static_cast<T>(0.0));
163139

164140
int out_h = ctx.Attr<int>("out_h");
165141
int out_w = ctx.Attr<int>("out_w");
166142

167-
auto out_size_t = ctx.Input<Tensor>("OutSize");
168-
if (out_size_t != nullptr) {
143+
auto out_size = ctx.Input<Tensor>("OutSize");
144+
if (out_size != nullptr) {
169145
Tensor sizes;
170-
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
146+
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
171147
auto size_data = sizes.data<int>();
172148
out_h = size_data[0];
173149
out_w = size_data[1];
174150
}
175151

176-
int batch_size = d_input_t->dims()[0];
177-
int channels = d_input_t->dims()[1];
178-
int in_h = d_input_t->dims()[2];
179-
int in_w = d_input_t->dims()[3];
152+
int n = input_grad->dims()[0];
153+
int c = input_grad->dims()[1];
154+
int in_h = input_grad->dims()[2];
155+
int in_w = input_grad->dims()[3];
180156

181157
int in_hw = in_h * in_w;
182158
int out_hw = out_h * out_w;
183-
int in_chw = channels * in_hw;
184-
int out_chw = channels * out_hw;
159+
int in_chw = c * in_hw;
160+
int out_chw = c * out_hw;
185161

186162
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
187163
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
188164

189165
if (in_h == out_h && in_w == out_w) {
190-
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
191-
} else {
192-
int threadNum = batch_size * out_chw;
193-
int blocks = (threadNum + 1024 - 1) / 1024;
194-
195-
KeBilinearInterpBw<
196-
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
197-
d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
198-
batch_size, out_chw, channels, ratio_h, ratio_w);
166+
memcpy(input_grad, output_grad, input_grad->numel() * sizeof(T));
167+
return;
199168
}
169+
170+
int threadNum = n * out_chw;
171+
int blocks = (threadNum + 1024 - 1) / 1024;
172+
173+
KeNearestNeighborInterpBw<
174+
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
175+
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
176+
n, out_chw, c, ratio_h, ratio_w);
200177
}
201178
};
202179

@@ -206,5 +183,5 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
206183
namespace ops = paddle::operators;
207184
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp,
208185
ops::NearestNeighborInterpOpCUDAKernel<float>);
209-
REGISTER_OP_CUDA_KERNEL(nearest_neighborinterp_grad,
186+
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp_grad,
210187
ops::NearestNeighborInterpGradOpCUDAKernel<float>);

python/paddle/fluid/layers/nn.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
'image_resize',
102102
'image_resize_short',
103103
'resize_bilinear',
104+
'resize_nearest',
104105
'gather',
105106
'scatter',
106107
'sequence_scatter',
@@ -5584,6 +5585,7 @@ def image_resize(input,
55845585
Supporting resample methods:
55855586
55865587
'BILINEAR' : Bilinear interpolation
5588+
'NEAREST' : Nearest neighbor interpolation
55875589
55885590
Args:
55895591
input (Variable): The input tensor of image resize layer,
@@ -5610,13 +5612,17 @@ def image_resize(input,
56105612
56115613
out = fluid.layers.image_resize(input, out_shape=[12, 12])
56125614
"""
5613-
resample_methods = {'BILINEAR': 'bilinear_interp'}
5615+
resample_methods = {
5616+
'BILINEAR': 'bilinear_interp',
5617+
'NEAREST': 'nearest_neighbor_interp'
5618+
}
56145619
if resample not in resample_methods:
56155620
raise ValueError(
5616-
"The 'resample' of image_resize can only be 'BILINEAR' currently.")
5621+
"The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently."
5622+
)
56175623
if out_shape is None and scale is None:
56185624
raise ValueError("One of out_shape and scale must not be None")
5619-
helper = LayerHelper('bilinear_interp', **locals())
5625+
helper = LayerHelper(resample_methods[resample], **locals())
56205626
dtype = helper.input_dtype()
56215627

56225628
def _is_list_or_turple_(data):
@@ -5672,6 +5678,29 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
56725678
return image_resize(input, out_shape, scale, name, 'BILINEAR')
56735679

56745680

5681+
@templatedoc(op_type="bilinear_interp")
5682+
def resize_nearest(input, out_shape=None, scale=None, name=None):
5683+
"""
5684+
${comment}
5685+
5686+
Args:
5687+
input(${x_type}): ${x_comment}.
5688+
5689+
out_shape(${out_size_type}): ${out_size_comment}.
5690+
5691+
scale(float|None): The multiplier for the input height or width. At
5692+
least one of out_shape or scale must be set. And out_shape has
5693+
a higher priority than scale. Default: None.
5694+
5695+
name(str|None): The output variable name.
5696+
5697+
Returns:
5698+
${out_comment}.
5699+
"""
5700+
5701+
return image_resize(input, out_shape, scale, name, 'NEAREST')
5702+
5703+
56755704
def image_resize_short(input, out_short_len, resample='BILINEAR'):
56765705
"""
56775706
Resize a batch of images. The short edge of input images will be

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,16 @@ def test_resize_bilinear(self):
485485
self.assertIsNotNone(output)
486486
print(str(program))
487487

488+
def test_resize_bilinear(self):
489+
program = Program()
490+
with program_guard(program):
491+
x = layers.data(name='x', shape=[3, 9, 6], dtype="float32")
492+
output = layers.resize_nearest(x, out_shape=[12, 12])
493+
self.assertIsNotNone(output)
494+
output = layers.resize_nearest(x, scale=3)
495+
self.assertIsNotNone(output)
496+
print(str(program))
497+
488498
def test_polygon_box_transform(self):
489499
program = Program()
490500
with program_guard(program):

0 commit comments

Comments
 (0)