@@ -1952,6 +1952,78 @@ return %2 : tensor<1x12x4xf32>
 
 }
 
+// -----
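+// The identical Relu applied to every Split output is moved before the Split, leaving a single Relu on the Split input.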
+func.func @test_split_relu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.Relu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_relu_movement
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Relu"([[PARAM_0_]]) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Split"([[VAR_1_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
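+// Not rewriting since the Split outputs are not all consumed by the same op: one result feeds a LeakyRelu instead of a Relu.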
+func.func @test_split_relu_movement_not_all_equal(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_relu_movement_not_all_equal
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Relu"([[VAR_1_]]#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Relu"([[VAR_1_]]#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
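+// The identical LeakyRelu (same alpha on every Split output) is moved before the Split.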
+func.func @test_split_leakyrelu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.LeakyRelu"(%0#0) {onnx_node_name = "onnx.LRelu_1", alpha = 0.2 : f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.LRelu_2", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.LeakyRelu"(%0#2) {onnx_node_name = "onnx.LRelu_3", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_leakyrelu_movement
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.LeakyRelu"([[PARAM_0_]]) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_1"} : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Split"([[VAR_1_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
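+// Not rewriting since the LeakyRelu consumers do not all use the same alpha value.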
+func.func @test_split_leakyrelu_movement_different_alpha(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.LeakyRelu"(%0#0) {onnx_node_name = "onnx.LRelu_1", alpha = 0.2 : f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.LRelu_2", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.LeakyRelu"(%0#2) {onnx_node_name = "onnx.LRelu_3", alpha = 0.3 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_leakyrelu_movement_different_alpha
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#0) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#1) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#2) {alpha = 3.000000e-01 : f32, onnx_node_name = "onnx.LRelu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
 // -----
 
 // Not rewriting since the operand in ConcatOp is neither DimOp nor ConstantOp.