@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
31
31
return backward_op
32
32
33
33
34
+ def _reference_testing (x , scale , offset , mean , var , epsilon , data_format ):
35
+ x_shape = x .shape
36
+ if len (x_shape ) == 2 :
37
+ if data_format == "NCHW" :
38
+ x = np .reshape (x , (x .shape [0 ], x .shape [1 ], 1 , 1 ))
39
+ else :
40
+ x = np .reshape (x , (x .shape [0 ], 1 , 1 , x .shape [1 ]))
41
+
42
+ if data_format == "NCHW" :
43
+ n , c , h , w = x .shape
44
+ mean_tile = np .reshape (mean , (1 , c , 1 , 1 ))
45
+ mean_tile = np .tile (mean_tile , (n , 1 , h , w ))
46
+ var_tile = np .reshape (var , (1 , c , 1 , 1 ))
47
+ var_tile = np .tile (var_tile , (n , 1 , h , w ))
48
+ normalized = (x - mean_tile ) / np .sqrt (var_tile + epsilon )
49
+ scale_tile = np .reshape (scale , (1 , c , 1 , 1 ))
50
+ scale_tile = np .tile (scale_tile , (n , 1 , h , w ))
51
+ offset_tile = np .reshape (offset , (1 , c , 1 , 1 ))
52
+ offset_tile = np .reshape (offset_tile , (1 , c , 1 , 1 ))
53
+ y = normalized * scale_tile + offset_tile
54
+ elif data_format == "NHWC" :
55
+ normalized = (x - mean ) / np .sqrt (var + epsilon )
56
+ y = normalized * scale + offset
57
+ else :
58
+ raise ValueError ("Unknown data order." )
59
+
60
+ if len (x_shape ) == 2 :
61
+ y = np .reshape (y , x_shape )
62
+ return y
63
+
64
+
34
65
def _reference_training (x , scale , offset , epsilon , data_format ):
35
66
x_shape = x .shape
36
67
if len (x_shape ) == 2 :
@@ -155,7 +186,43 @@ def __set_tensor__(name, data=None):
155
186
__set_tensor__ (output , data )
156
187
157
188
158
- class TestBatchNormOp (OpTest ):
189
+ class TestBatchNormOpInference (OpTest ):
190
+ def setUp (self ):
191
+ self .dtype = np .float32
192
+
193
+ def test_python (self ):
194
+ data_format = "NHWC"
195
+ epsilon = 0.00001
196
+
197
+ n , h , w , c = 2 , 3 , 4 , 5
198
+ x_shape = [n , h , w , c ]
199
+ scale_shape = [c ]
200
+
201
+ x_val = np .random .random_sample (x_shape ).astype (self .dtype )
202
+ scale_val = np .random .random_sample (scale_shape ).astype (self .dtype )
203
+ bias_val = np .random .random_sample (scale_shape ).astype (self .dtype )
204
+
205
+ mean = np .zeros (scale_shape ).astype (self .dtype )
206
+ variance = np .ones (scale_shape ).astype (self .dtype )
207
+
208
+ # run forward
209
+ y_out = _reference_testing (x_val , scale_val , bias_val , mean , variance ,
210
+ epsilon , "NHWC" )
211
+
212
+ # running N, C, H, W case
213
+ # should produce the same results
214
+ x_shape2 = [n , c , h , w ]
215
+ x_val2 = np .transpose (x_val , (0 , 3 , 1 , 2 ))
216
+ y_out2 = _reference_testing (x_val2 , scale_val , bias_val , mean , variance ,
217
+ epsilon , "NCHW" )
218
+
219
+ # transfer (N, C, H, W) back to (N, H, W, C)
220
+ y_out2_trans = np .transpose (y_out2 , (0 , 2 , 3 , 1 ))
221
+ self .__assert_close (y_out , y_out2_trans , "inference output" )
222
+ print 'python: NHWC, NCHW, inference checking passed'
223
+
224
+
225
+ class TestBatchNormOpTraining (OpTest ):
159
226
def __assert_close (self , tensor , np_array , msg , atol = 1e-4 ):
160
227
self .assertTrue (np .allclose (np .array (tensor ), np_array , atol = atol ), msg )
161
228
0 commit comments