@@ -188,14 +188,27 @@ def __set_tensor__(name, data=None):
188
188
189
189
class TestBatchNormOpInference (OpTest ):
190
190
def setUp (self ):
191
+ self .op_type = "conv2d"
192
+ self .is_test = True
191
193
self .dtype = np .float32
194
+ self .data_layout = "NCHW"
195
+ init_dtype ()
196
+ init_data_layout ()
197
+ init_test_case ()
192
198
193
- def test_python (self ):
194
- data_format = "NHWC"
195
199
epsilon = 0.00001
196
-
197
- n , h , w , c = 2 , 3 , 4 , 5
198
- x_shape = [n , h , w , c ]
200
+ shape = self .shape
201
+ if len (shape ) == 2 :
202
+ x_shape = shape
203
+ c = x_shape [1 ]
204
+ else :
205
+ n , h , w , c = shape [0 ], shape [1 ], shape [2 ], shape [3 ]
206
+ if self .data_layout == "NHWC" :
207
+ x_shape = [n , h , w , c ]
208
+ elif self .data_layout == "NCHW" :
209
+ x_shape = [n , c , h , w ]
210
+ else :
211
+ raise ValueError ("Unknown data layout." )
199
212
scale_shape = [c ]
200
213
201
214
x_val = np .random .random_sample (x_shape ).astype (self .dtype )
@@ -205,7 +218,64 @@ def test_python(self):
205
218
mean = np .zeros (scale_shape ).astype (self .dtype )
206
219
variance = np .ones (scale_shape ).astype (self .dtype )
207
220
208
- # run forward
221
+ saved_mean = np .zeros (scale_shape ).astype (self .dtype )
222
+ saved_variance = np .ones (scale_shape ).astype (self .dtype )
223
+
224
+ y_out = _reference_testing (x_val , scale_val , bias_val , mean , variance ,
225
+ epsilon , self .data_layout ).astype (self .dtype )
226
+
227
+ self .inputs = {
228
+ 'X' : OpTest .np_dtype_to_fluid_dtype (x_val ),
229
+ 'Scale' : OpTest .np_dtype_to_fluid_dtype (scale_val ),
230
+ 'Bias' : OpTest .np_dtype_to_fluid_dtype (bias_val ),
231
+ 'Mean' : OpTest .np_dtype_to_fluid_dtype (mean ),
232
+ 'Variance' : OpTest .np_dtype_to_fluid_dtype (variance )
233
+ }
234
+ self .attrs = {
235
+ 'is_test' : self .is_test ,
236
+ 'epsilon' : epsilon ,
237
+ 'data_layout' : self .data_layout
238
+ }
239
+ self .outputs = {
240
+ 'Y' : y_out ,
241
+ 'MeanOut' : mean ,
242
+ 'VarianceOut' : variance ,
243
+ 'SavedMean' : saved_mean ,
244
+ 'SavedVariance' : saved_variance
245
+ }
246
+
247
+ def test_check_output (self ):
248
+ self .check_output ()
249
+
250
+ def init_dtype (self ):
251
+ pass
252
+
253
+ def init_data_layout (self ):
254
+ pass
255
+
256
+ def init_test_case (self ):
257
+ self .shape = [2 , 3 , 4 , 5 ]
258
+
259
+
260
+ class TestBatchNormOpTraining (OpTest ):
261
+ def __assert_close (self , tensor , np_array , msg , atol = 1e-4 ):
262
+ self .assertTrue (np .allclose (np .array (tensor ), np_array , atol = atol ), msg )
263
+
264
+ def test_python_testing (self ):
265
+ data_format = "NHWC"
266
+ epsilon = 0.00001
267
+
268
+ n , h , w , c = 2 , 3 , 4 , 5
269
+ x_shape = [n , h , w , c ]
270
+ scale_shape = [c ]
271
+
272
+ x_val = np .random .random_sample (x_shape ).astype (np .float32 )
273
+ scale_val = np .random .random_sample (scale_shape ).astype (np .float32 )
274
+ bias_val = np .random .random_sample (scale_shape ).astype (np .float32 )
275
+
276
+ mean = np .zeros (scale_shape ).astype (np .float32 )
277
+ variance = np .ones (scale_shape ).astype (np .float32 )
278
+
209
279
y_out = _reference_testing (x_val , scale_val , bias_val , mean , variance ,
210
280
epsilon , "NHWC" )
211
281
@@ -218,15 +288,11 @@ def test_python(self):
218
288
219
289
# transfer (N, C, H, W) back to (N, H, W, C)
220
290
y_out2_trans = np .transpose (y_out2 , (0 , 2 , 3 , 1 ))
221
- self .__assert_close (y_out , y_out2_trans , "inference output" )
291
+ self .__assert_close (y_out , y_out2_trans ,
292
+ "inference outputs of two formats have differences" )
222
293
print 'python: NHWC, NCHW, inference checking passed'
223
294
224
-
225
- class TestBatchNormOpTraining (OpTest ):
226
- def __assert_close (self , tensor , np_array , msg , atol = 1e-4 ):
227
- self .assertTrue (np .allclose (np .array (tensor ), np_array , atol = atol ), msg )
228
-
229
- def test_python (self ):
295
+ def test_python_training (self ):
230
296
data_format = "NHWC"
231
297
epsilon = 0.00001
232
298
momentum = 0.9
@@ -264,7 +330,7 @@ def test_python(self):
264
330
265
331
# transfer (N, C, H, W) back to (N, H, W, C)
266
332
y_out2_trans = np .transpose (y_out2 , (0 , 2 , 3 , 1 ))
267
- self .__assert_close (y_out , y_out2_trans , "batch variance " )
333
+ self .__assert_close (y_out , y_out2_trans , "batch output " )
268
334
print 'python: NHWC, NCHW, forward checking passed'
269
335
270
336
# test backward now
0 commit comments