Skip to content

Commit 251a1bb

Browse files
authored
Merge pull request #14588 from heavengate/revert_interpolate
fix interpolate_op incompatible. test=develop
2 parents 3ae6692 + bb489d4 commit 251a1bb

File tree

6 files changed

+252
-170
lines changed

6 files changed

+252
-170
lines changed

paddle/fluid/operators/interpolate_op.cc

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,12 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
7676

7777
AddAttr<int>("out_h", "output height of interpolate op.");
7878
AddAttr<int>("out_w", "output width of interpolate op.");
79-
AddAttr<std::string>(
80-
"interp_method",
81-
"(string), interpolation method, can be \"bilinear\" for "
82-
"bilinear interpolation and \"nearest\" for nearest "
83-
"neighbor interpolation.");
79+
AddAttr<std::string>("interp_method",
80+
"(string, default \"bilinear\"), interpolation "
81+
"method, can be \"bilinear\" for "
82+
"bilinear interpolation and \"nearest\" for nearest "
83+
"neighbor interpolation.")
84+
.SetDefault("bilinear");
8485
AddComment(R"DOC(
8586
This operator samples input X to given output shape by using specified
8687
interpolation method, the interpolation methods can be \"nearest\"
@@ -132,11 +133,19 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
132133
} // namespace paddle
133134

134135
namespace ops = paddle::operators;
135-
REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker,
136+
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
136137
paddle::framework::DefaultGradOpDescMaker<true>);
137-
REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad);
138-
REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel<float>,
138+
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad);
139+
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
140+
paddle::framework::DefaultGradOpDescMaker<true>);
141+
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad);
142+
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
143+
ops::InterpolateKernel<double>,
144+
ops::InterpolateKernel<uint8_t>);
145+
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel<float>,
146+
ops::InterpolateGradKernel<double>);
147+
REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
139148
ops::InterpolateKernel<double>,
140149
ops::InterpolateKernel<uint8_t>);
141-
REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel<float>,
150+
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
142151
ops::InterpolateGradKernel<double>);

paddle/fluid/operators/interpolate_op.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,15 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
284284
} // namespace paddle
285285

286286
namespace ops = paddle::operators;
287-
REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel<float>,
287+
REGISTER_OP_CUDA_KERNEL(bilinear_interp, ops::InterpolateOpCUDAKernel<float>,
288288
ops::InterpolateOpCUDAKernel<double>,
289289
ops::InterpolateOpCUDAKernel<int>);
290-
REGISTER_OP_CUDA_KERNEL(interpolate_grad,
290+
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
291+
ops::InterpolateGradOpCUDAKernel<float>,
292+
ops::InterpolateGradOpCUDAKernel<double>);
293+
REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel<float>,
294+
ops::InterpolateOpCUDAKernel<double>,
295+
ops::InterpolateOpCUDAKernel<int>);
296+
REGISTER_OP_CUDA_KERNEL(nearest_interp_grad,
291297
ops::InterpolateGradOpCUDAKernel<float>,
292298
ops::InterpolateGradOpCUDAKernel<double>);

python/paddle/fluid/layers/nn.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5870,9 +5870,10 @@ def image_resize(input,
58705870
raise ValueError(
58715871
"The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently."
58725872
)
5873+
resample_type = resample_methods[resample]
58735874
if out_shape is None and scale is None:
58745875
raise ValueError("One of out_shape and scale must not be None.")
5875-
helper = LayerHelper('interpolate', **locals())
5876+
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
58765877
dtype = helper.input_dtype()
58775878

58785879
def _is_list_or_turple_(data):
@@ -5906,18 +5907,16 @@ def _is_list_or_turple_(data):
59065907

59075908
out = helper.create_variable_for_type_inference(dtype)
59085909
helper.append_op(
5909-
type='interpolate',
5910+
type='{}_interp'.format(resample_type),
59105911
inputs=inputs,
59115912
outputs={"Out": out},
5912-
attrs={
5913-
"out_h": out_h,
5914-
"out_w": out_w,
5915-
"interp_method": resample_methods[resample]
5916-
})
5913+
attrs={"out_h": out_h,
5914+
"out_w": out_w,
5915+
"interp_method": resample_type})
59175916
return out
59185917

59195918

5920-
@templatedoc(op_type="interpolate")
5919+
@templatedoc(op_type="bilinear_interp")
59215920
def resize_bilinear(input,
59225921
out_shape=None,
59235922
scale=None,
@@ -5973,7 +5972,7 @@ def resize_bilinear(input,
59735972
return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape)
59745973

59755974

5976-
@templatedoc(op_type="interpolate")
5975+
@templatedoc(op_type="nearest_interp")
59775976
def resize_nearest(input,
59785977
out_shape=None,
59795978
scale=None,

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,14 @@ list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
8181
list(REMOVE_ITEM TEST_OPS test_dist_transformer)
8282
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
8383
list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
84-
list(REMOVE_ITEM TEST_OPS test_interpolate_op)
84+
list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
85+
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
8586
foreach(TEST_OP ${TEST_OPS})
8687
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
8788
endforeach(TEST_OP)
8889
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
89-
py_test_modules(test_interpolate_op MODULES test_interpolate_op SERIAL)
90+
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
91+
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
9092
if(WITH_DISTRIBUTE)
9193
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
9294
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)

python/paddle/fluid/tests/unittests/test_interpolate_op.py renamed to python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py

Lines changed: 17 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,6 @@
2020
import paddle.fluid.core as core
2121

2222

23-
def nearest_neighbor_interp_np(X,
24-
out_h,
25-
out_w,
26-
out_size=None,
27-
actual_shape=None):
28-
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
29-
if out_size is not None:
30-
out_h = out_size[0]
31-
out_w = out_size[1]
32-
if actual_shape is not None:
33-
out_h = actual_shape[0]
34-
out_w = actual_shape[1]
35-
n, c, in_h, in_w = X.shape
36-
37-
ratio_h = ratio_w = 0.0
38-
if out_h > 1:
39-
ratio_h = (in_h - 1.0) / (out_h - 1.0)
40-
if out_w > 1:
41-
ratio_w = (in_w - 1.0) / (out_w - 1.0)
42-
43-
out = np.zeros((n, c, out_h, out_w))
44-
for i in range(out_h):
45-
in_i = int(ratio_h * i + 0.5)
46-
for j in range(out_w):
47-
in_j = int(ratio_w * j + 0.5)
48-
out[:, :, i, j] = X[:, :, in_i, in_j]
49-
50-
return out.astype(X.dtype)
51-
52-
5323
def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None):
5424
"""bilinear interpolation implement in shape [N, C, H, W]"""
5525
if out_size is not None:
@@ -87,22 +57,16 @@ def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None):
8757
return out.astype(input.dtype)
8858

8959

90-
INTERPOLATE_FUNCS = {
91-
'bilinear': bilinear_interp_np,
92-
'nearest': nearest_neighbor_interp_np,
93-
}
94-
95-
96-
class TestInterpolateOp(OpTest):
60+
class TestBilinearInterpOp(OpTest):
9761
def setUp(self):
9862
self.out_size = None
9963
self.actual_shape = None
10064
self.init_test_case()
101-
self.op_type = "interpolate"
65+
self.op_type = "bilinear_interp"
10266
input_np = np.random.random(self.input_shape).astype("float32")
10367

104-
output_np = INTERPOLATE_FUNCS[self.interp_method](
105-
input_np, self.out_h, self.out_w, self.out_size, self.actual_shape)
68+
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
69+
self.out_size, self.actual_shape)
10670
self.inputs = {'X': input_np}
10771
if self.out_size is not None:
10872
self.inputs['OutSize'] = self.out_size
@@ -129,31 +93,31 @@ def init_test_case(self):
12993
self.out_size = np.array([3, 3]).astype("int32")
13094

13195

132-
class TestBilinearInterpCase1(TestInterpolateOp):
96+
class TestBilinearInterpCase1(TestBilinearInterpOp):
13397
def init_test_case(self):
13498
self.interp_method = 'bilinear'
13599
self.input_shape = [4, 1, 7, 8]
136100
self.out_h = 1
137101
self.out_w = 1
138102

139103

140-
class TestBilinearInterpCase2(TestInterpolateOp):
104+
class TestBilinearInterpCase2(TestBilinearInterpOp):
141105
def init_test_case(self):
142106
self.interp_method = 'bilinear'
143107
self.input_shape = [3, 3, 9, 6]
144108
self.out_h = 12
145109
self.out_w = 12
146110

147111

148-
class TestBilinearInterpCase3(TestInterpolateOp):
112+
class TestBilinearInterpCase3(TestBilinearInterpOp):
149113
def init_test_case(self):
150114
self.interp_method = 'bilinear'
151115
self.input_shape = [1, 1, 128, 64]
152116
self.out_h = 64
153117
self.out_w = 128
154118

155119

156-
class TestBilinearInterpCase4(TestInterpolateOp):
120+
class TestBilinearInterpCase4(TestBilinearInterpOp):
157121
def init_test_case(self):
158122
self.interp_method = 'bilinear'
159123
self.input_shape = [4, 1, 7, 8]
@@ -162,7 +126,7 @@ def init_test_case(self):
162126
self.out_size = np.array([2, 2]).astype("int32")
163127

164128

165-
class TestBilinearInterpCase5(TestInterpolateOp):
129+
class TestBilinearInterpCase5(TestBilinearInterpOp):
166130
def init_test_case(self):
167131
self.interp_method = 'bilinear'
168132
self.input_shape = [3, 3, 9, 6]
@@ -171,7 +135,7 @@ def init_test_case(self):
171135
self.out_size = np.array([11, 11]).astype("int32")
172136

173137

174-
class TestBilinearInterpCase6(TestInterpolateOp):
138+
class TestBilinearInterpCase6(TestBilinearInterpOp):
175139
def init_test_case(self):
176140
self.interp_method = 'bilinear'
177141
self.input_shape = [1, 1, 128, 64]
@@ -180,7 +144,7 @@ def init_test_case(self):
180144
self.out_size = np.array([65, 129]).astype("int32")
181145

182146

183-
class TestBilinearInterpActualShape(TestInterpolateOp):
147+
class TestBilinearInterpActualShape(TestBilinearInterpOp):
184148
def init_test_case(self):
185149
self.interp_method = 'bilinear'
186150
self.input_shape = [3, 2, 32, 16]
@@ -189,25 +153,16 @@ def init_test_case(self):
189153
self.out_size = np.array([66, 40]).astype("int32")
190154

191155

192-
class TestBilinearInterpBigScale(TestInterpolateOp):
193-
def init_test_case(self):
194-
self.interp_method = 'bilinear'
195-
self.input_shape = [4, 4, 64, 32]
196-
self.out_h = 100
197-
self.out_w = 50
198-
self.out_size = np.array([101, 51]).astype('int32')
199-
200-
201-
class TestInterpolateOpUint8(OpTest):
156+
class TestBilinearInterpOpUint8(OpTest):
202157
def setUp(self):
203158
self.out_size = None
204159
self.actual_shape = None
205160
self.init_test_case()
206-
self.op_type = "interpolate"
161+
self.op_type = "bilinear_interp"
207162
input_np = np.random.randint(
208163
low=0, high=256, size=self.input_shape).astype("uint8")
209-
output_np = INTERPOLATE_FUNCS[self.interp_method](
210-
input_np, self.out_h, self.out_w, self.out_size, self.actual_shape)
164+
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
165+
self.out_size, self.actual_shape)
211166
self.inputs = {'X': input_np}
212167
if self.out_size is not None:
213168
self.inputs['OutSize'] = self.out_size
@@ -228,15 +183,15 @@ def init_test_case(self):
228183
self.out_w = 9
229184

230185

231-
class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8):
186+
class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8):
232187
def init_test_case(self):
233188
self.interp_method = 'bilinear'
234189
self.input_shape = [2, 3, 128, 64]
235190
self.out_h = 120
236191
self.out_w = 50
237192

238193

239-
class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8):
194+
class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8):
240195
def init_test_case(self):
241196
self.interp_method = 'bilinear'
242197
self.input_shape = [4, 1, 7, 8]
@@ -245,91 +200,5 @@ def init_test_case(self):
245200
self.out_size = np.array([6, 15]).astype("int32")
246201

247202

248-
class TestNearestNeighborInterpCase1(TestInterpolateOp):
249-
def init_test_case(self):
250-
self.interp_method = 'nearest'
251-
self.input_shape = [4, 1, 7, 8]
252-
self.out_h = 1
253-
self.out_w = 1
254-
255-
256-
class TestNearestNeighborInterpCase2(TestInterpolateOp):
257-
def init_test_case(self):
258-
self.interp_method = 'nearest'
259-
self.input_shape = [3, 3, 9, 6]
260-
self.out_h = 12
261-
self.out_w = 12
262-
263-
264-
class TestNearestNeighborInterpCase3(TestInterpolateOp):
265-
def init_test_case(self):
266-
self.interp_method = 'nearest'
267-
self.input_shape = [1, 1, 128, 64]
268-
self.out_h = 64
269-
self.out_w = 128
270-
271-
272-
class TestNearestNeighborInterpCase4(TestInterpolateOp):
273-
def init_test_case(self):
274-
self.interp_method = 'nearest'
275-
self.input_shape = [4, 1, 7, 8]
276-
self.out_h = 1
277-
self.out_w = 1
278-
self.out_size = np.array([2, 2]).astype("int32")
279-
280-
281-
class TestNearestNeighborInterpCase5(TestInterpolateOp):
282-
def init_test_case(self):
283-
self.interp_method = 'nearest'
284-
self.input_shape = [3, 3, 9, 6]
285-
self.out_h = 12
286-
self.out_w = 12
287-
self.out_size = np.array([11, 11]).astype("int32")
288-
289-
290-
class TestNearestNeighborInterpCase6(TestInterpolateOp):
291-
def init_test_case(self):
292-
self.interp_method = 'nearest'
293-
self.input_shape = [1, 1, 128, 64]
294-
self.out_h = 64
295-
self.out_w = 128
296-
self.out_size = np.array([65, 129]).astype("int32")
297-
298-
299-
class TestNearestNeighborInterpActualShape(TestInterpolateOp):
300-
def init_test_case(self):
301-
self.interp_method = 'nearest'
302-
self.input_shape = [3, 2, 32, 16]
303-
self.out_h = 64
304-
self.out_w = 32
305-
self.out_size = np.array([66, 40]).astype("int32")
306-
307-
308-
class TestNearestNeighborInterpBigScale(TestInterpolateOp):
309-
def init_test_case(self):
310-
self.interp_method = 'nearest'
311-
self.input_shape = [4, 4, 64, 32]
312-
self.out_h = 100
313-
self.out_w = 50
314-
self.out_size = np.array([101, 51]).astype('int32')
315-
316-
317-
class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8):
318-
def init_test_case(self):
319-
self.interp_method = 'nearest'
320-
self.input_shape = [2, 3, 128, 64]
321-
self.out_h = 120
322-
self.out_w = 50
323-
324-
325-
class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8):
326-
def init_test_case(self):
327-
self.interp_method = 'nearest'
328-
self.input_shape = [4, 1, 7, 8]
329-
self.out_h = 5
330-
self.out_w = 13
331-
self.out_size = np.array([6, 15]).astype("int32")
332-
333-
334203
if __name__ == "__main__":
335204
unittest.main()

0 commit comments

Comments
 (0)