@@ -1796,9 +1796,8 @@ def func(x, y):
1796
1796
y_val = np .array (9 , dtype = np .int32 )
1797
1797
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val , _INPUT1 : y_val })
1798
1798
1799
- @skip_caffe2_backend ("fails with schema error" )
1800
1799
@check_opset_min_version (7 , "batchnorm" )
1801
- def test_batchnorm (self ):
1800
+ def test_fused_batchnorm (self ):
1802
1801
x_shape = [1 , 28 , 28 , 2 ]
1803
1802
x_dtype = np .float32
1804
1803
scale_dtype = np .float32
@@ -1822,6 +1821,25 @@ def func(x):
1822
1821
return tf .identity (y , name = _TFOUTPUT )
1823
1822
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 )
1824
1823
1824
+ @check_opset_min_version (7 , "batchnorm" )
1825
+ @check_tf_min_version ("1.13" )
1826
+ def test_batchnorm (self ):
1827
+ x_shape = [1 , 128 , 128 , 2 ]
1828
+ x_dtype = np .float32
1829
+ scale_dtype = np .float32
1830
+ scale_shape = [2 ]
1831
+ x_val = np .random .random_sample (x_shape ).astype (x_dtype )
1832
+ scale_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
1833
+ offset_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
1834
+ mean_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
1835
+ var_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
1836
+ def func (x , mean , offset , var ):
1837
+ scale = tf .constant (scale_val , name = 'scale' )
1838
+ epsilon = 0.001
1839
+ y = tf .nn .batch_normalization (x , mean , var , offset , scale , epsilon )
1840
+ return tf .identity (y , name = _TFOUTPUT )
1841
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val , _INPUT1 : mean_val , _INPUT2 : offset_val , _INPUT3 : var_val })
1842
+
1825
1843
@skip_caffe2_backend ()
1826
1844
@check_opset_min_version (7 , "resize_nearest_neighbor" )
1827
1845
def test_resize_nearest_neighbor (self ):
0 commit comments