20
20
import paddle .fluid .core as core
21
21
22
22
23
- def nearest_neighbor_interp_np (X ,
24
- out_h ,
25
- out_w ,
26
- out_size = None ,
27
- actual_shape = None ):
28
- """nearest neighbor interpolation implement in shape [N, C, H, W]"""
29
- if out_size is not None :
30
- out_h = out_size [0 ]
31
- out_w = out_size [1 ]
32
- if actual_shape is not None :
33
- out_h = actual_shape [0 ]
34
- out_w = actual_shape [1 ]
35
- n , c , in_h , in_w = X .shape
36
-
37
- ratio_h = ratio_w = 0.0
38
- if out_h > 1 :
39
- ratio_h = (in_h - 1.0 ) / (out_h - 1.0 )
40
- if out_w > 1 :
41
- ratio_w = (in_w - 1.0 ) / (out_w - 1.0 )
42
-
43
- out = np .zeros ((n , c , out_h , out_w ))
44
- for i in range (out_h ):
45
- in_i = int (ratio_h * i + 0.5 )
46
- for j in range (out_w ):
47
- in_j = int (ratio_w * j + 0.5 )
48
- out [:, :, i , j ] = X [:, :, in_i , in_j ]
49
-
50
- return out .astype (X .dtype )
51
-
52
-
53
23
def bilinear_interp_np (input , out_h , out_w , out_size = None , actual_shape = None ):
54
24
"""bilinear interpolation implement in shape [N, C, H, W]"""
55
25
if out_size is not None :
@@ -87,22 +57,16 @@ def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None):
87
57
return out .astype (input .dtype )
88
58
89
59
90
- INTERPOLATE_FUNCS = {
91
- 'bilinear' : bilinear_interp_np ,
92
- 'nearest' : nearest_neighbor_interp_np ,
93
- }
94
-
95
-
96
- class TestInterpolateOp (OpTest ):
60
+ class TestBilinearInterpOp (OpTest ):
97
61
def setUp (self ):
98
62
self .out_size = None
99
63
self .actual_shape = None
100
64
self .init_test_case ()
101
- self .op_type = "interpolate "
65
+ self .op_type = "bilinear_interp "
102
66
input_np = np .random .random (self .input_shape ).astype ("float32" )
103
67
104
- output_np = INTERPOLATE_FUNCS [ self .interp_method ](
105
- input_np , self . out_h , self . out_w , self .out_size , self .actual_shape )
68
+ output_np = bilinear_interp_np ( input_np , self .out_h , self . out_w ,
69
+ self .out_size , self .actual_shape )
106
70
self .inputs = {'X' : input_np }
107
71
if self .out_size is not None :
108
72
self .inputs ['OutSize' ] = self .out_size
@@ -129,31 +93,31 @@ def init_test_case(self):
129
93
self .out_size = np .array ([3 , 3 ]).astype ("int32" )
130
94
131
95
132
- class TestBilinearInterpCase1 (TestInterpolateOp ):
96
+ class TestBilinearInterpCase1 (TestBilinearInterpOp ):
133
97
def init_test_case (self ):
134
98
self .interp_method = 'bilinear'
135
99
self .input_shape = [4 , 1 , 7 , 8 ]
136
100
self .out_h = 1
137
101
self .out_w = 1
138
102
139
103
140
- class TestBilinearInterpCase2 (TestInterpolateOp ):
104
+ class TestBilinearInterpCase2 (TestBilinearInterpOp ):
141
105
def init_test_case (self ):
142
106
self .interp_method = 'bilinear'
143
107
self .input_shape = [3 , 3 , 9 , 6 ]
144
108
self .out_h = 12
145
109
self .out_w = 12
146
110
147
111
148
- class TestBilinearInterpCase3 (TestInterpolateOp ):
112
+ class TestBilinearInterpCase3 (TestBilinearInterpOp ):
149
113
def init_test_case (self ):
150
114
self .interp_method = 'bilinear'
151
115
self .input_shape = [1 , 1 , 128 , 64 ]
152
116
self .out_h = 64
153
117
self .out_w = 128
154
118
155
119
156
- class TestBilinearInterpCase4 (TestInterpolateOp ):
120
+ class TestBilinearInterpCase4 (TestBilinearInterpOp ):
157
121
def init_test_case (self ):
158
122
self .interp_method = 'bilinear'
159
123
self .input_shape = [4 , 1 , 7 , 8 ]
@@ -162,7 +126,7 @@ def init_test_case(self):
162
126
self .out_size = np .array ([2 , 2 ]).astype ("int32" )
163
127
164
128
165
- class TestBilinearInterpCase5 (TestInterpolateOp ):
129
+ class TestBilinearInterpCase5 (TestBilinearInterpOp ):
166
130
def init_test_case (self ):
167
131
self .interp_method = 'bilinear'
168
132
self .input_shape = [3 , 3 , 9 , 6 ]
@@ -171,7 +135,7 @@ def init_test_case(self):
171
135
self .out_size = np .array ([11 , 11 ]).astype ("int32" )
172
136
173
137
174
- class TestBilinearInterpCase6 (TestInterpolateOp ):
138
+ class TestBilinearInterpCase6 (TestBilinearInterpOp ):
175
139
def init_test_case (self ):
176
140
self .interp_method = 'bilinear'
177
141
self .input_shape = [1 , 1 , 128 , 64 ]
@@ -180,7 +144,7 @@ def init_test_case(self):
180
144
self .out_size = np .array ([65 , 129 ]).astype ("int32" )
181
145
182
146
183
- class TestBilinearInterpActualShape (TestInterpolateOp ):
147
+ class TestBilinearInterpActualShape (TestBilinearInterpOp ):
184
148
def init_test_case (self ):
185
149
self .interp_method = 'bilinear'
186
150
self .input_shape = [3 , 2 , 32 , 16 ]
@@ -189,25 +153,16 @@ def init_test_case(self):
189
153
self .out_size = np .array ([66 , 40 ]).astype ("int32" )
190
154
191
155
192
- class TestBilinearInterpBigScale (TestInterpolateOp ):
193
- def init_test_case (self ):
194
- self .interp_method = 'bilinear'
195
- self .input_shape = [4 , 4 , 64 , 32 ]
196
- self .out_h = 100
197
- self .out_w = 50
198
- self .out_size = np .array ([101 , 51 ]).astype ('int32' )
199
-
200
-
201
- class TestInterpolateOpUint8 (OpTest ):
156
+ class TestBilinearInterpOpUint8 (OpTest ):
202
157
def setUp (self ):
203
158
self .out_size = None
204
159
self .actual_shape = None
205
160
self .init_test_case ()
206
- self .op_type = "interpolate "
161
+ self .op_type = "bilinear_interp "
207
162
input_np = np .random .randint (
208
163
low = 0 , high = 256 , size = self .input_shape ).astype ("uint8" )
209
- output_np = INTERPOLATE_FUNCS [ self .interp_method ](
210
- input_np , self . out_h , self . out_w , self .out_size , self .actual_shape )
164
+ output_np = bilinear_interp_np ( input_np , self .out_h , self . out_w ,
165
+ self .out_size , self .actual_shape )
211
166
self .inputs = {'X' : input_np }
212
167
if self .out_size is not None :
213
168
self .inputs ['OutSize' ] = self .out_size
@@ -228,15 +183,15 @@ def init_test_case(self):
228
183
self .out_w = 9
229
184
230
185
231
- class TestBilinearInterpCase1Uint8 (TestInterpolateOpUint8 ):
186
+ class TestBilinearInterpCase1Uint8 (TestBilinearInterpOpUint8 ):
232
187
def init_test_case (self ):
233
188
self .interp_method = 'bilinear'
234
189
self .input_shape = [2 , 3 , 128 , 64 ]
235
190
self .out_h = 120
236
191
self .out_w = 50
237
192
238
193
239
- class TestBilinearInterpCase2Uint8 (TestInterpolateOpUint8 ):
194
+ class TestBilinearInterpCase2Uint8 (TestBilinearInterpOpUint8 ):
240
195
def init_test_case (self ):
241
196
self .interp_method = 'bilinear'
242
197
self .input_shape = [4 , 1 , 7 , 8 ]
@@ -245,91 +200,5 @@ def init_test_case(self):
245
200
self .out_size = np .array ([6 , 15 ]).astype ("int32" )
246
201
247
202
248
- class TestNearestNeighborInterpCase1 (TestInterpolateOp ):
249
- def init_test_case (self ):
250
- self .interp_method = 'nearest'
251
- self .input_shape = [4 , 1 , 7 , 8 ]
252
- self .out_h = 1
253
- self .out_w = 1
254
-
255
-
256
- class TestNearestNeighborInterpCase2 (TestInterpolateOp ):
257
- def init_test_case (self ):
258
- self .interp_method = 'nearest'
259
- self .input_shape = [3 , 3 , 9 , 6 ]
260
- self .out_h = 12
261
- self .out_w = 12
262
-
263
-
264
- class TestNearestNeighborInterpCase3 (TestInterpolateOp ):
265
- def init_test_case (self ):
266
- self .interp_method = 'nearest'
267
- self .input_shape = [1 , 1 , 128 , 64 ]
268
- self .out_h = 64
269
- self .out_w = 128
270
-
271
-
272
- class TestNearestNeighborInterpCase4 (TestInterpolateOp ):
273
- def init_test_case (self ):
274
- self .interp_method = 'nearest'
275
- self .input_shape = [4 , 1 , 7 , 8 ]
276
- self .out_h = 1
277
- self .out_w = 1
278
- self .out_size = np .array ([2 , 2 ]).astype ("int32" )
279
-
280
-
281
- class TestNearestNeighborInterpCase5 (TestInterpolateOp ):
282
- def init_test_case (self ):
283
- self .interp_method = 'nearest'
284
- self .input_shape = [3 , 3 , 9 , 6 ]
285
- self .out_h = 12
286
- self .out_w = 12
287
- self .out_size = np .array ([11 , 11 ]).astype ("int32" )
288
-
289
-
290
- class TestNearestNeighborInterpCase6 (TestInterpolateOp ):
291
- def init_test_case (self ):
292
- self .interp_method = 'nearest'
293
- self .input_shape = [1 , 1 , 128 , 64 ]
294
- self .out_h = 64
295
- self .out_w = 128
296
- self .out_size = np .array ([65 , 129 ]).astype ("int32" )
297
-
298
-
299
- class TestNearestNeighborInterpActualShape (TestInterpolateOp ):
300
- def init_test_case (self ):
301
- self .interp_method = 'nearest'
302
- self .input_shape = [3 , 2 , 32 , 16 ]
303
- self .out_h = 64
304
- self .out_w = 32
305
- self .out_size = np .array ([66 , 40 ]).astype ("int32" )
306
-
307
-
308
- class TestNearestNeighborInterpBigScale (TestInterpolateOp ):
309
- def init_test_case (self ):
310
- self .interp_method = 'nearest'
311
- self .input_shape = [4 , 4 , 64 , 32 ]
312
- self .out_h = 100
313
- self .out_w = 50
314
- self .out_size = np .array ([101 , 51 ]).astype ('int32' )
315
-
316
-
317
- class TestNearestNeighborInterpCase1Uint8 (TestInterpolateOpUint8 ):
318
- def init_test_case (self ):
319
- self .interp_method = 'nearest'
320
- self .input_shape = [2 , 3 , 128 , 64 ]
321
- self .out_h = 120
322
- self .out_w = 50
323
-
324
-
325
- class TestNearestNeighborInterpCase2Uint8 (TestInterpolateOpUint8 ):
326
- def init_test_case (self ):
327
- self .interp_method = 'nearest'
328
- self .input_shape = [4 , 1 , 7 , 8 ]
329
- self .out_h = 5
330
- self .out_w = 13
331
- self .out_size = np .array ([6 , 15 ]).astype ("int32" )
332
-
333
-
334
203
if __name__ == "__main__" :
335
204
unittest .main ()
0 commit comments