@@ -1962,6 +1962,32 @@ func.func @test_remove_where_equal_4(%arg0: tensor<?x?xi64>) -> tensor<2xi64> {
19621962
19631963// -----
19641964
1965+ func.func @test_not_where_opt_1 (%arg0: tensor <1 x10 xi1 >, %arg1: tensor <1 x10 xbf16 >, %arg2: tensor <1 x10 xbf16 >) -> tensor <1 x10 xbf16 > {
1966+ %0 = " onnx.Not" (%arg0 ) : (tensor <1 x10 xi1 >) -> tensor <1 x10 xi1 >
1967+ %1 = " onnx.Where" (%0 , %arg1 , %arg2 ) : (tensor <1 x10 xi1 >, tensor <1 x10 xbf16 >, tensor <1 x10 xbf16 >) -> tensor <1 x10 xbf16 >
1968+ onnx.Return %1 : tensor <1 x10 xbf16 >
1969+ // CHECK-LABEL: func.func @test_not_where_opt_1
1970+ // CHECK-SAME: ([[ARG_0_:%.+]]: tensor<1x10xi1>, [[ARG_1_:%.+]]: tensor<1x10xbf16>, [[ARG_2_:%.+]]: tensor<1x10xbf16>) -> tensor<1x10xbf16> {
1971+ // CHECK-NOT: onnx.Not
1972+ // CHECK: [[VAR_0_:%.+]] = "onnx.Where"([[ARG_0_]], [[ARG_2_]], [[ARG_1_]]) : (tensor<1x10xi1>, tensor<1x10xbf16>, tensor<1x10xbf16>) -> tensor<1x10xbf16>
1973+ // CHECK: onnx.Return [[VAR_0_]] : tensor<1x10xbf16>
1974+ }
1975+
1976+ // -----
1977+
1978+ func.func @test_not_where_opt_2 (%arg0: tensor <1 x10 xi1 >, %arg1: tensor <1 x10 xbf16 >, %arg2: tensor <1 x10 xbf16 >) -> (tensor <1 x10 xi1 >, tensor <1 x10 xbf16 >) {
1979+ %0 = " onnx.Not" (%arg0 ) : (tensor <1 x10 xi1 >) -> tensor <1 x10 xi1 >
1980+ %1 = " onnx.Where" (%0 , %arg1 , %arg2 ) : (tensor <1 x10 xi1 >, tensor <1 x10 xbf16 >, tensor <1 x10 xbf16 >) -> tensor <1 x10 xbf16 >
1981+ onnx.Return %0 , %1 : tensor <1 x10 xi1 >, tensor <1 x10 xbf16 >
1982+ // CHECK-LABEL: func.func @test_not_where_opt_2
1983+ // CHECK-SAME: ([[ARG_0_:%.+]]: tensor<1x10xi1>, [[ARG_1_:%.+]]: tensor<1x10xbf16>, [[ARG_2_:%.+]]: tensor<1x10xbf16>) -> (tensor<1x10xi1>, tensor<1x10xbf16>) {
1984+ // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Not"([[ARG_0_]]) : (tensor<1x10xi1>) -> tensor<1x10xi1>
1985+ // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Where"([[ARG_0_]], [[ARG_2_]], [[ARG_1_]]) : (tensor<1x10xi1>, tensor<1x10xbf16>, tensor<1x10xbf16>) -> tensor<1x10xbf16>
1986+ // CHECK: onnx.Return [[VAR_0_]], [[VAR_1_]] : tensor<1x10xi1>, tensor<1x10xbf16>
1987+ }
1988+
1989+ // -----
1990+
19651991func.func @test_recompose_concat (%arg0: tensor <1 x3 x4 xf32 >, %arg1: tensor <1 x3 x4 xf32 > ) -> tensor <1 x12 x4 xf32 > {
19661992%0 = " onnx.Concat" (%arg0 , %arg1 ) {axis = 1 : si64 , onnx_node_name = " onnx.Concat_0" } : (tensor <1 x3 x4 xf32 >, tensor <1 x3 x4 xf32 >) -> tensor <1 x6 x4 xf32 >
19671993%1 = " onnx.Concat" (%0 , %arg0 ) {axis = 1 : si64 , onnx_node_name = " onnx.Concat_1" } : (tensor <1 x6 x4 xf32 >, tensor <1 x3 x4 xf32 >) -> tensor <1 x9 x4 xf32 >
0 commit comments