@@ -95,6 +95,7 @@ def run_onnxmsrtnext(onnx_graph, inputs, output_names, test_name):
95
95
"""Run test against msrt-next backend."""
96
96
import lotus
97
97
model_path = os .path .join (TMPPATH , test_name + ".pb" )
98
+ # print(model_path)
98
99
with open (model_path , "wb" ) as f :
99
100
f .write (onnx_graph .SerializeToString ())
100
101
m = lotus .InferenceSession (model_path )
@@ -801,7 +802,6 @@ def test_cast(self):
801
802
self .assertAllClose (expected , actual )
802
803
803
804
def test_onehot0 (self ):
804
- # no such op in onnx
805
805
x_val = np .array ([0 , 1 , 2 ], dtype = np .int32 )
806
806
depth = 5
807
807
for axis in [- 1 , 0 , 1 ]:
@@ -814,7 +814,7 @@ def test_onehot0(self):
814
814
815
815
@unittest .skip
816
816
def test_onehot1 (self ):
817
- # no such op in onnx
817
+ # only rank 1 is currently implemented
818
818
x_val = np .array ([[0 , 2 ], [1 , - 1 ]], dtype = np .int32 )
819
819
depth = 3
820
820
x = tf .placeholder (tf .int32 , x_val .shape , name = _TFINPUT )
@@ -824,7 +824,6 @@ def test_onehot1(self):
824
824
self .assertAllClose (expected , actual )
825
825
826
826
def test_onehot2 (self ):
827
- # no such op in onnx
828
827
x_val = np .array ([0 , 1 , 2 , 1 , 2 , 0 , 1 , 2 , 1 , 2 ], dtype = np .int32 )
829
828
depth = 20
830
829
x = tf .placeholder (tf .int32 , x_val .shape , name = _TFINPUT )
@@ -924,6 +923,60 @@ def test_strided_slice2(self):
924
923
actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
925
924
self .assertAllClose (expected , actual )
926
925
926
+ @unittest .skip
927
+ def test_strided_slice3 (self ):
928
+ x_val = np .arange (3 * 2 * 3 ).astype ("float32" ).reshape (3 , 2 , 3 )
929
+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
930
+ x_ = x [1 :]
931
+ output = tf .identity (x_ , name = _TFOUTPUT )
932
+ actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
933
+ self .assertAllClose (expected , actual )
934
+
935
+ @unittest .skip
936
+ def test_strided_slice4 (self ):
937
+ x_val = np .arange (3 * 2 * 3 ).astype ("float32" ).reshape (3 , 2 , 3 )
938
+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
939
+ x_ = x [:2 ]
940
+ output = tf .identity (x_ , name = _TFOUTPUT )
941
+ actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
942
+ self .assertAllClose (expected , actual )
943
+
944
+ @unittest .skip
945
+ def test_strided_slice5 (self ):
946
+ x_val = np .arange (3 * 2 * 3 ).astype ("float32" ).reshape (3 , 2 , 3 )
947
+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
948
+ x_ = x [:2 , 0 :1 , 1 :]
949
+ output = tf .identity (x_ , name = _TFOUTPUT )
950
+ actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
951
+ self .assertAllClose (expected , actual )
952
+
953
+ @unittest .skipIf (BACKEND in ["caffe2" , "onnxmsrt" ], "fails with schema error" )
954
+ def test_batchnorm (self ):
955
+ x_shape = [1 , 28 , 28 , 2 ]
956
+ x_dtype = np .float32
957
+ scale_dtype = np .float32
958
+ scale_shape = [2 ]
959
+ # only nhwc is support on cpu for tensorflow
960
+ data_format = "NHWC"
961
+ x_val = np .random .random_sample (x_shape ).astype (x_dtype )
962
+ scale_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
963
+ offset_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
964
+ mean_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
965
+ var_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
966
+
967
+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
968
+ scale = tf .constant (scale_val , name = 'scale' )
969
+ offset = tf .constant (offset_val , name = 'offset' )
970
+ mean = tf .constant (mean_val , name = 'mean' )
971
+ var = tf .constant (var_val , name = 'variance' )
972
+ epsilon = 0.001
973
+ y , _ , _ = tf .nn .fused_batch_norm (
974
+ x , scale , offset , mean = mean , variance = var ,
975
+ epsilon = epsilon , data_format = data_format , is_training = False )
976
+ output = tf .identity (y , name = _TFOUTPUT )
977
+ actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
978
+ self .assertAllClose (expected , actual , rtol = 1e-04 )
979
+
927
980
@unittest .skipIf (BACKEND in ["caffe2" , "onnxmsrt" ], "not correctly supported" )
928
981
def test_resize_nearest_neighbor (self ):
929
982
x_shape = [1 , 15 , 20 , 2 ]
@@ -964,10 +1017,13 @@ def test_fill(self):
964
1017
parser .add_argument ('--backend' , default = 'caffe2' ,
965
1018
choices = ["caffe2" , "onnxmsrt" , "onnxmsrtnext" , "onnx-tensorflow" ],
966
1019
help = "backend to test against" )
1020
+ parser .add_argument ('--opset' , default = OPSET ,
1021
+ help = "opset to test against" )
967
1022
parser .add_argument ('unittest_args' , nargs = '*' )
968
1023
969
1024
args = parser .parse_args ()
970
1025
BACKEND = args .backend
1026
+ OPSET = args .opset
971
1027
# Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
972
1028
sys .argv [1 :] = args .unittest_args
973
1029
unittest .main ()
0 commit comments