Skip to content

Commit 41b51c2

Browse files
authored
Merge pull request #816 from onnx/gs/bn
more batchnorm ut
2 parents 2ffc9d5 + 5757564 commit 41b51c2

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

tests/test_backend.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,9 +1796,8 @@ def func(x, y):
17961796
y_val = np.array(9, dtype=np.int32)
17971797
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
17981798

1799-
@skip_caffe2_backend("fails with schema error")
18001799
@check_opset_min_version(7, "batchnorm")
1801-
def test_batchnorm(self):
1800+
def test_fused_batchnorm(self):
18021801
x_shape = [1, 28, 28, 2]
18031802
x_dtype = np.float32
18041803
scale_dtype = np.float32
@@ -1822,6 +1821,25 @@ def func(x):
18221821
return tf.identity(y, name=_TFOUTPUT)
18231822
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
18241823

1824+
@check_opset_min_version(7, "batchnorm")
1825+
@check_tf_min_version("1.13")
1826+
def test_batchnorm(self):
1827+
x_shape = [1, 128, 128, 2]
1828+
x_dtype = np.float32
1829+
scale_dtype = np.float32
1830+
scale_shape = [2]
1831+
x_val = np.random.random_sample(x_shape).astype(x_dtype)
1832+
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
1833+
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
1834+
mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
1835+
var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
1836+
def func(x, mean, offset, var):
1837+
scale = tf.constant(scale_val, name='scale')
1838+
epsilon = 0.001
1839+
y = tf.nn.batch_normalization(x, mean, var, offset, scale, epsilon)
1840+
return tf.identity(y, name=_TFOUTPUT)
1841+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: mean_val, _INPUT2: offset_val, _INPUT3: var_val})
1842+
18251843
@skip_caffe2_backend()
18261844
@check_opset_min_version(7, "resize_nearest_neighbor")
18271845
def test_resize_nearest_neighbor(self):

0 commit comments

Comments
 (0)