@@ -1457,6 +1457,15 @@ def forward(self, x):
 
 save_data_and_model_multy_inputs("einsum_2d", einsum, mat1, mat2, export_params=True)
 
+# 2d test case with ellipses
+mat1 = torch.randn(4, 5)
+mat2 = torch.randn(5, 8)
+equation = "...ij, ...jk -> ...ik"
+einsum = Einsum(equation)
+output = einsum(mat1, mat2)
+
+save_data_and_model_multy_inputs("einsum_2d_ellipses", einsum, mat1, mat2, export_params=True)
+
 # 3d test case
 mat1 = torch.ones(2, 4, 5)
 mat2 = torch.ones(2, 5, 8)
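For context (not part of the diff): the ellipsis in `...ij, ...jk -> ...ik` stands for any leading batch dimensions, so with plain 2-D inputs the equation reduces to an ordinary matrix product, and with extra leading dimensions it becomes a batched matmul. A minimal standalone check of that reading, assuming only `torch` is available:

import torch

a = torch.randn(4, 5)
b = torch.randn(5, 8)
# with 2-D inputs the ellipsis matches nothing, so this is a plain matmul
assert torch.allclose(torch.einsum("...ij, ...jk -> ...ik", a, b), a @ b)

# the same equation also covers batched inputs, e.g. (2, 4, 5) x (2, 5, 8)
a3, b3 = torch.randn(2, 4, 5), torch.randn(2, 5, 8)
assert torch.allclose(torch.einsum("...ij, ...jk -> ...ik", a3, b3), a3 @ b3)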
@@ -2519,6 +2528,25 @@ def forward(self, x):
 x = torch.randn(2, 3, 4)
 save_data_and_model("cumsum_3d_dim_2", x, CumSum(dim=2), version=11)
 
+# test: CumSum exclusive layer should not be executed inplace
+dims = h.make_node("Constant", inputs=[], outputs=["dims1"], name="node-c1",
+                   value=h.make_tensor(name="c1v", data_type=onnx.TensorProto.INT64, dims=[], vals=np.asarray([1, ], dtype=np.int64)))
+one = h.make_node("Constant", inputs=[], outputs=["one1"], name="node-c2",
+                  value=h.make_tensor(name="c2v", data_type=onnx.TensorProto.FLOAT, dims=[], vals=np.asarray([1, ], dtype=np.float32)))
+
+mult = h.make_node("Mul", inputs=["input1", "one1"], outputs=["mul_output1"], name="node-m1")
+cumsum = h.make_node("CumSum", inputs=["mul_output1", "dims1"], outputs=["cumsum_output1"], name="node-r1", exclusive=1)
+
+graph = h.make_graph([dims, one, mult, cumsum], "graph123",
+                     [h.make_tensor_value_info("input1", onnx.TensorProto.FLOAT, [1, 3, 1, 1]),],
+                     [h.make_tensor_value_info("cumsum_output1", onnx.TensorProto.FLOAT, [1, 3, 1, 1])])
+cumsum_model = h.make_model(graph, producer_name="model_cumsum")
+onnx.checker.check_model(cumsum_model)
+
+input_np = np.array([1, 2, 3], dtype=np.float32).reshape(1, 3, 1, 1)
+output_np = np.array([0, 1, 3], dtype=np.float32).reshape(1, 3, 1, 1)
+save_data_and_onnx_model("cumsum_exclusive_inplace", input_np, output_np, cumsum_model)
+
 # where layer
 class Where(nn.Module):
     def __init__(self, *args, **kwargs):
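For context (not part of the diff): with `exclusive=1`, CumSum computes out[i] = x[0] + ... + x[i-1], so out[0] = 0 and the input [1, 2, 3] above yields [0, 1, 3]. That shift is exactly why the layer must not write into the buffer it reads from: element i is overwritten before it is needed for element i+1. A small NumPy sketch of both the expected result and the failure mode of a naive in-place pass (the loop below is only an illustration, not any library's actual implementation):

import numpy as np

x = np.array([1, 2, 3], dtype=np.float32)

# exclusive cumulative sum: out[i] = x[0] + ... + x[i-1], so out[0] = 0
expected = np.concatenate(([0.0], np.cumsum(x)[:-1])).astype(np.float32)
assert np.array_equal(expected, np.array([0, 1, 3], dtype=np.float32))

# a naive pass over the same buffer breaks: writing the shifted value
# destroys the input element still needed for the next running sum
bad = x.copy()
running = np.float32(0)
for i in range(len(bad)):
    bad[i] = running
    running += bad[i]  # reads the already-overwritten value
assert np.array_equal(bad, np.zeros(3, dtype=np.float32))  # [0, 0, 0], not [0, 1, 3]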