@@ -1958,6 +1958,87 @@ def test_const_fold_cast_with_const(self):
1958
1958
self .run_and_compare (["res" ], {"X" : np .random .randn (* shape ).astype (np .int64 )}, model_proto ,
1959
1959
"Cast" , 0 )
1960
1960
1961
+ def test_const_fold_split (self ):
1962
+ shape = (2 , 6 , 1 )
1963
+ const_tensor = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
1964
+ vals = np .random .randn (2 , 6 , 1 ).flatten ().astype (np .float32 ))
1965
+ node0 = helper .make_node ("Constant" , [], ["const" ], value = const_tensor )
1966
+ node1 = helper .make_node ("Split" , ["const" ], ["out1" , "out2" , "out3" ], axis = 1 )
1967
+ node2 = helper .make_node ("Sum" , ["inp" , "out1" , "out2" , "out3" ], ["out4" ])
1968
+
1969
+ graph = helper .make_graph (
1970
+ [node0 , node1 , node2 ],
1971
+ "test_const_fold_split" ,
1972
+ [helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , (2 , 2 , 1 ))],
1973
+ [helper .make_tensor_value_info ("out4" , TensorProto .FLOAT , (2 , 2 , 1 ))],
1974
+ )
1975
+
1976
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1977
+ self .run_and_compare (["out4" ], {"inp" : np .random .randn (2 , 2 , 1 ).astype (np .float32 )}, model_proto ,
1978
+ "Split" , 0 )
1979
+
1980
+ def test_const_fold_split_one (self ):
1981
+ shape = (2 , 6 , 1 )
1982
+ const_tensor = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
1983
+ vals = np .random .randn (2 , 6 , 1 ).flatten ().astype (np .float32 ))
1984
+ node0 = helper .make_node ("Constant" , [], ["const" ], value = const_tensor )
1985
+ node1 = helper .make_node ("Split" , ["const" ], ["out1" ], axis = 1 )
1986
+ node2 = helper .make_node ("Sum" , ["inp" , "out1" ], ["out4" ])
1987
+
1988
+ graph = helper .make_graph (
1989
+ [node0 , node1 , node2 ],
1990
+ "test_const_fold_split" ,
1991
+ [helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , (2 , 6 , 1 ))],
1992
+ [helper .make_tensor_value_info ("out4" , TensorProto .FLOAT , (2 , 6 , 1 ))],
1993
+ )
1994
+
1995
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1996
+ self .run_and_compare (["out4" ], {"inp" : np .random .randn (2 , 6 , 1 ).astype (np .float32 )}, model_proto ,
1997
+ "Split" , 0 )
1998
+
1999
+ @check_opset_min_version (13 , "Split changed in opset 13" )
2000
+ def test_const_fold_split_const_splits_13 (self ):
2001
+ shape = (2 , 6 , 1 )
2002
+ const_tensor = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2003
+ vals = np .random .randn (2 , 6 , 1 ).flatten ().astype (np .float32 ))
2004
+ node0 = helper .make_node ("Constant" , [], ["const" ], value = const_tensor )
2005
+ const_splits = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .INT64 , dims = [3 ],
2006
+ vals = np .array ([1 , 3 , 2 ], np .int64 ))
2007
+ node1 = helper .make_node ("Constant" , [], ["splits" ], value = const_splits )
2008
+ node2 = helper .make_node ("Split" , ["const" , "splits" ], ["out1" , "out2" , "out3" ], axis = 1 )
2009
+ node3 = helper .make_node ("Sum" , ["inp" , "out2" ], ["out4" ])
2010
+
2011
+ graph = helper .make_graph (
2012
+ [node0 , node1 , node2 , node3 ],
2013
+ "test_const_fold_split" ,
2014
+ [helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , (2 , 3 , 1 ))],
2015
+ [helper .make_tensor_value_info ("out4" , TensorProto .FLOAT , (2 , 3 , 1 ))],
2016
+ )
2017
+
2018
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
2019
+ self .run_and_compare (["out4" ], {"inp" : np .random .randn (2 , 3 , 1 ).astype (np .float32 )}, model_proto ,
2020
+ "Split" , 0 )
2021
+
2022
+ @check_opset_max_version (12 , "Split changed in opset 13" )
2023
+ def test_const_fold_split_const_splits (self ):
2024
+ shape = (2 , 6 , 1 )
2025
+ const_tensor = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2026
+ vals = np .random .randn (2 , 6 , 1 ).flatten ().astype (np .float32 ))
2027
+ node0 = helper .make_node ("Constant" , [], ["const" ], value = const_tensor )
2028
+ node2 = helper .make_node ("Split" , ["const" ], ["out1" , "out2" , "out3" ], axis = 1 , split = [1 , 3 , 2 ])
2029
+ node3 = helper .make_node ("Sum" , ["inp" , "out2" ], ["out4" ])
2030
+
2031
+ graph = helper .make_graph (
2032
+ [node0 , node2 , node3 ],
2033
+ "test_const_fold_split" ,
2034
+ [helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , (2 , 3 , 1 ))],
2035
+ [helper .make_tensor_value_info ("out4" , TensorProto .FLOAT , (2 , 3 , 1 ))],
2036
+ )
2037
+
2038
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
2039
+ self .run_and_compare (["out4" ], {"inp" : np .random .randn (2 , 3 , 1 ).astype (np .float32 )}, model_proto ,
2040
+ "Split" , 0 )
2041
+
1961
2042
# Const Fold Optimizer Tests End
1962
2043
1963
2044
# Const Dequantize Optimizer Tests Start
0 commit comments