@@ -1898,6 +1898,28 @@ def test_const_fold_node_is_output(self):
1898
1898
self .run_transpose_compare (["res" ], {},
1899
1899
model_proto , remaining_transpose_num = 0 )
1900
1900
1901
+ def test_const_fold_concat (self ):
1902
+ shape = (6 , 4 )
1903
+ const_tensor = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
1904
+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
1905
+ const_tensor2 = helper .make_tensor (name = 'const_tensor2' , data_type = TensorProto .FLOAT , dims = shape ,
1906
+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
1907
+ node1 = helper .make_node ("Constant" , [], ["const" ], value = const_tensor )
1908
+ node2 = helper .make_node ("Constant" , [], ["const2" ], value = const_tensor2 )
1909
+ node3 = helper .make_node ("Concat" , ["const" , "const2" , "const" ], ["value1" ], axis = 1 )
1910
+ node4 = helper .make_node ("Add" , ["value1" , "inp" ], ["res" ])
1911
+
1912
+ graph = helper .make_graph (
1913
+ [node1 , node2 , node3 , node4 ],
1914
+ "test_const_fold_trans_with_const2" ,
1915
+ [helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , [6 , 12 ])],
1916
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , [6 , 12 ])],
1917
+ )
1918
+
1919
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1920
+ self .run_and_compare (["res" ], {"inp" : np .random .randn (6 , 12 ).astype (np .float32 )}, model_proto ,
1921
+ "Concat" , 0 )
1922
+
1901
1923
@check_opset_max_version (12 , "Squeeze/Unsqueeze changed in opset 13" )
1902
1924
def test_const_fold_unsqueeze_with_const (self ):
1903
1925
shape = (6 , 6 )
0 commit comments