@@ -187,74 +187,99 @@ def __set_tensor__(name, data=None):
187
187
188
188
189
189
class TestBatchNormOpInference (OpTest ):
190
- def setUp (self ):
191
- self .op_type = "conv2d"
192
- self .is_test = True
193
- self .dtype = np .float32
194
- self .data_layout = "NCHW"
195
- init_dtype ()
196
- init_data_layout ()
197
- init_test_case ()
190
+ def __assert_close (self , tensor , np_array , msg , atol = 1e-4 ):
191
+ self .assertTrue (np .allclose (np .array (tensor ), np_array , atol = atol ), msg )
198
192
199
- epsilon = 0.00001
200
- shape = self .shape
201
- if len (shape ) == 2 :
202
- x_shape = shape
203
- c = x_shape [1 ]
204
- else :
205
- n , h , w , c = shape [0 ], shape [1 ], shape [2 ], shape [3 ]
206
- if self .data_layout == "NHWC" :
207
- x_shape = [n , h , w , c ]
208
- elif self .data_layout == "NCHW" :
209
- x_shape = [n , c , h , w ]
193
+ def test_inference (self ):
194
+ def test_with_place (place , data_layout , dtype , shape ):
195
+ epsilon = 0.00001
196
+ if len (shape ) == 2 :
197
+ x_shape = shape
198
+ c = x_shape [1 ]
210
199
else :
211
- raise ValueError ("Unknown data layout." )
212
- scale_shape = [c ]
200
+ n , h , w , c = shape [0 ], shape [1 ], shape [2 ], shape [3 ]
201
+ if data_layout == "NHWC" :
202
+ x_shape = [n , h , w , c ]
203
+ elif data_layout == "NCHW" :
204
+ x_shape = [n , c , h , w ]
205
+ else :
206
+ raise ValueError ("Unknown data layout." )
207
+ scale_shape = [c ]
213
208
214
- x_val = np .random .random_sample (x_shape ).astype (self . dtype )
215
- scale_val = np .random .random_sample (scale_shape ).astype (self . dtype )
216
- bias_val = np .random .random_sample (scale_shape ).astype (self . dtype )
209
+ x_val = np .random .random_sample (x_shape ).astype (dtype )
210
+ scale_val = np .random .random_sample (scale_shape ).astype (dtype )
211
+ bias_val = np .random .random_sample (scale_shape ).astype (dtype )
217
212
218
- mean = np .zeros (scale_shape ).astype (self . dtype )
219
- variance = np .ones (scale_shape ).astype (self . dtype )
213
+ mean = np .zeros (scale_shape ).astype (dtype )
214
+ variance = np .ones (scale_shape ).astype (dtype )
220
215
221
- saved_mean = np .zeros (scale_shape ).astype (self .dtype )
222
- saved_variance = np .ones (scale_shape ).astype (self .dtype )
216
+ y_out = _reference_testing (x_val , scale_val , bias_val , mean ,
217
+ variance , epsilon ,
218
+ data_layout ).astype (dtype )
223
219
224
- y_out = _reference_testing (x_val , scale_val , bias_val , mean , variance ,
225
- epsilon , self .data_layout ).astype (self .dtype )
226
-
227
- self .inputs = {
228
- 'X' : OpTest .np_dtype_to_fluid_dtype (x_val ),
229
- 'Scale' : OpTest .np_dtype_to_fluid_dtype (scale_val ),
230
- 'Bias' : OpTest .np_dtype_to_fluid_dtype (bias_val ),
231
- 'Mean' : OpTest .np_dtype_to_fluid_dtype (mean ),
232
- 'Variance' : OpTest .np_dtype_to_fluid_dtype (variance )
233
- }
234
- self .attrs = {
235
- 'is_test' : self .is_test ,
236
- 'epsilon' : epsilon ,
237
- 'data_layout' : self .data_layout
238
- }
239
- self .outputs = {
240
- 'Y' : y_out ,
241
- 'MeanOut' : mean ,
242
- 'VarianceOut' : variance ,
243
- 'SavedMean' : saved_mean ,
244
- 'SavedVariance' : saved_variance
245
- }
246
-
247
- def test_check_output (self ):
248
- self .check_output ()
249
-
250
- def init_dtype (self ):
251
- pass
252
-
253
- def init_data_layout (self ):
254
- pass
255
-
256
- def init_test_case (self ):
257
- self .shape = [2 , 3 , 4 , 5 ]
220
+ scope = core .Scope ()
221
+
222
+ # create input
223
+ x_tensor = create_or_get_tensor (
224
+ scope , "x_val" , OpTest .np_dtype_to_fluid_dtype (x_val ), place )
225
+ scale_tensor = create_or_get_tensor (
226
+ scope , "scale_val" ,
227
+ OpTest .np_dtype_to_fluid_dtype (scale_val ), place )
228
+ bias_tensor = create_or_get_tensor (
229
+ scope , "bias_val" ,
230
+ OpTest .np_dtype_to_fluid_dtype (bias_val ), place )
231
+ mean_tensor = create_or_get_tensor (
232
+ scope , "mean" , OpTest .np_dtype_to_fluid_dtype (mean ), place )
233
+ variance_tensor = create_or_get_tensor (
234
+ scope , "variance" ,
235
+ OpTest .np_dtype_to_fluid_dtype (variance ), place )
236
+
237
+ # create output
238
+ y_tensor = create_or_get_tensor (scope , "y_out" , None , place )
239
+ saved_mean_tensor = create_or_get_tensor (scope , "saved_mean" , None ,
240
+ place )
241
+ saved_variance_tensor = create_or_get_tensor (
242
+ scope , "saved_variance" , None , place )
243
+ mean_out_tensor = mean_tensor
244
+ variance_out_tensor = variance_tensor
245
+
246
+ batch_norm_op = Operator (
247
+ "batch_norm" ,
248
+ # inputs
249
+ X = "x_val" ,
250
+ Scale = "scale_val" ,
251
+ Bias = "bias_val" ,
252
+ Mean = "mean" ,
253
+ Variance = "variance" ,
254
+ # outputs
255
+ Y = "y_out" ,
256
+ MeanOut = "mean" ,
257
+ VarianceOut = "variance" ,
258
+ SavedMean = "saved_mean" ,
259
+ SavedVariance = "saved_variance" ,
260
+ # attrs
261
+ is_test = True ,
262
+ data_layout = data_layout ,
263
+ epsilon = epsilon )
264
+
265
+ batch_norm_op .run (scope , place )
266
+
267
+ # check inference result
268
+ self .__assert_close (
269
+ y_tensor , y_out , "inference output are different at " +
270
+ str (place ) + ", " + data_layout + ", " + str (np .dtype (dtype )))
271
+
272
+ places = [core .CPUPlace ()]
273
+ if core .is_compiled_with_cuda () and core .op_support_gpu ("batch_norm" ):
274
+ place = core .CUDAPlace (0 )
275
+ if self .dtype != np .float16 or core .is_float16_supported (place ):
276
+ places .append (place )
277
+
278
+ for place in places :
279
+ for data_format in ["NCHW" , "NHWC" ]:
280
+ for dtype in [np .float32 , np .float16 ]:
281
+ test_with_place (place , data_format , dtype , [2 , 3 , 4 , 5 ])
282
+ test_with_place (place , data_format , dtype , [2 , 3 ])
258
283
259
284
260
285
class TestBatchNormOpTraining (OpTest ):
@@ -288,8 +313,7 @@ def test_python_testing(self):
288
313
289
314
# transfer (N, C, H, W) back to (N, H, W, C)
290
315
y_out2_trans = np .transpose (y_out2 , (0 , 2 , 3 , 1 ))
291
- self .__assert_close (y_out , y_out2_trans ,
292
- "inference outputs of two formats have differences" )
316
+ self .__assert_close (y_out , y_out2_trans , "inference output" )
293
317
print 'python: NHWC, NCHW, inference checking passed'
294
318
295
319
def test_python_training (self ):
0 commit comments