Skip to content

Commit 6af39f5

Browse files
author
Liang-Ta Wei
committed
feat: onnx.Not and onnx.Where optimization
1 parent e4cf47f commit 6af39f5

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,25 @@ struct PropagateBiasIntoLayerNormRewritePattern
19071907
// Rewrite pattern for Where
19081908
// =============================================================================
19091909

1910+
class NotWhereOptPattern : public OpRewritePattern<ONNXWhereOp> {
1911+
public:
1912+
using OpRewritePattern<ONNXWhereOp>::OpRewritePattern;
1913+
1914+
LogicalResult matchAndRewrite(
1915+
ONNXWhereOp onnxWhereOp, PatternRewriter &rewriter) const override {
1916+
auto notOp = onnxWhereOp.getCondition().getDefiningOp<ONNXNotOp>();
1917+
if (!notOp)
1918+
return failure();
1919+
rewriter.modifyOpInPlace(onnxWhereOp, [&]() {
1920+
onnxWhereOp.getOperation()->setOperands(
1921+
{notOp.getX(), onnxWhereOp.getY(), onnxWhereOp.getX()});
1922+
onnxWhereOp->setLoc(
1923+
rewriter.getFusedLoc({onnxWhereOp.getLoc(), notOp.getLoc()}));
1924+
});
1925+
return success();
1926+
}
1927+
};
1928+
19101929
class RemoveWhereEqualPattern : public OpRewritePattern<ONNXWhereOp> {
19111930
public:
19121931
using OpRewritePattern<ONNXWhereOp>::OpRewritePattern;
@@ -2751,6 +2770,7 @@ void ONNXWhereOp::getCanonicalizationPatterns(
27512770
RewritePatternSet &result, MLIRContext *context) {
27522771
result.insert<AlwaysFalseWherePattern>(context);
27532772
result.insert<RemoveWhereEqualPattern>(context);
2773+
result.insert<NotWhereOptPattern>(context);
27542774
}
27552775

27562776
// on the ONNXDequantizeLinearOp.

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,32 @@ func.func @test_remove_where_equal_4(%arg0: tensor<?x?xi64>) -> tensor<2xi64> {
19621962

19631963
// -----
19641964

1965+
func.func @test_not_where_opt_1(%arg0: tensor<1x10xi1>, %arg1: tensor<1x10xbf16>, %arg2: tensor<1x10xbf16>) -> tensor<1x10xbf16> {
1966+
%0 = "onnx.Not"(%arg0) : (tensor<1x10xi1>) -> tensor<1x10xi1>
1967+
%1 = "onnx.Where"(%0, %arg1, %arg2) : (tensor<1x10xi1>, tensor<1x10xbf16>, tensor<1x10xbf16>) -> tensor<1x10xbf16>
1968+
onnx.Return %1 : tensor<1x10xbf16>
1969+
// CHECK-LABEL: func.func @test_not_where_opt_1
1970+
// CHECK-SAME: ([[ARG_0_:%.+]]: tensor<1x10xi1>, [[ARG_1_:%.+]]: tensor<1x10xbf16>, [[ARG_2_:%.+]]: tensor<1x10xbf16>) -> tensor<1x10xbf16> {
1971+
// CHECK-NOT: onnx.Not
1972+
// CHECK: [[VAR_0_:%.+]] = "onnx.Where"([[ARG_0_]], [[ARG_2_]], [[ARG_1_]]) : (tensor<1x10xi1>, tensor<1x10xbf16>, tensor<1x10xbf16>) -> tensor<1x10xbf16>
1973+
// CHECK: onnx.Return [[VAR_0_]] : tensor<1x10xbf16>
1974+
}
1975+
1976+
// -----
1977+
1978+
func.func @test_not_where_opt_2(%arg0: tensor<1x10xi1>, %arg1: tensor<1x10xbf16>, %arg2: tensor<1x10xbf16>) -> (tensor<1x10xi1>, tensor<1x10xbf16>) {
1979+
%0 = "onnx.Not"(%arg0) : (tensor<1x10xi1>) -> tensor<1x10xi1>
1980+
%1 = "onnx.Where"(%0, %arg1, %arg2) : (tensor<1x10xi1>, tensor<1x10xbf16>, tensor<1x10xbf16>) -> tensor<1x10xbf16>
1981+
onnx.Return %0, %1 : tensor<1x10xi1>, tensor<1x10xbf16>
1982+
// CHECK-LABEL: func.func @test_not_where_opt_2
1983+
// CHECK-SAME: ([[ARG_0_:%.+]]: tensor<1x10xi1>, [[ARG_1_:%.+]]: tensor<1x10xbf16>, [[ARG_2_:%.+]]: tensor<1x10xbf16>) -> (tensor<1x10xi1>, tensor<1x10xbf16>) {
1984+
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Not"([[ARG_0_]]) : (tensor<1x10xi1>) -> tensor<1x10xi1>
1985+
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Where"([[ARG_0_]], [[ARG_2_]], [[ARG_1_]]) : (tensor<1x10xi1>, tensor<1x10xbf16>, tensor<1x10xbf16>) -> tensor<1x10xbf16>
1986+
// CHECK: onnx.Return [[VAR_0_]], [[VAR_1_]] : tensor<1x10xi1>, tensor<1x10xbf16>
1987+
}
1988+
1989+
// -----
1990+
19651991
func.func @test_recompose_concat(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32> ) -> tensor<1x12x4xf32> {
19661992
%0 = "onnx.Concat"(%arg0, %arg1) {axis = 1 : si64, onnx_node_name = "onnx.Concat_0"} : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x6x4xf32>
19671993
%1 = "onnx.Concat"(%0, %arg0) {axis = 1 : si64, onnx_node_name = "onnx.Concat_1"} : (tensor<1x6x4xf32>, tensor<1x3x4xf32>) -> tensor<1x9x4xf32>

0 commit comments

Comments
 (0)