diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 2cddbee97dff2..a2ee31bfd4637 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1196,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { }]; let arguments = (ins - Tosa_I1Tensor:$pred, - Tosa_Tensor:$on_true, - Tosa_Tensor:$on_false + Tosa_I1Tensor:$input1, + Tosa_Tensor:$input2, + Tosa_Tensor:$input3 ); let results = (outs @@ -1208,7 +1208,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { let hasFolder = 1; let assemblyFormat = [{ - operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false) + operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3) `)` `->` type($output) }]; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 22447e49fb3ea..3d4d7ccf5ebb2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -221,12 +221,12 @@ struct SelectLogicalNotOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::SelectOp op, PatternRewriter &rewriter) const override { - auto notOp = op.getPred().getDefiningOp(); + auto notOp = op.getInput1().getDefiningOp(); if (!notOp) return failure(); rewriter.modifyOpInPlace(op, [&]() { op.getOperation()->setOperands( - {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()}); + {notOp.getInput1(), op.getInput3(), op.getInput2()}); }); return success(); } @@ -247,7 +247,7 @@ struct SelectToClampOptimization : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::SelectOp op, PatternRewriter &rewriter) const override { - auto geq = op.getPred().getDefiningOp(); + auto geq = op.getInput1().getDefiningOp(); if (!geq) { return rewriter.notifyMatchFailure(op, "Predicate is not a GreaterEqualOp"); @@ -297,8 +297,8 @@ struct SelectToClampOptimization : public OpRewritePattern { return a.getSplatValue() == b.getSplatValue(); }; - auto onFalse = op.getOnFalse(); - auto onTrue = op.getOnTrue(); + auto onFalse = op.getInput3(); + auto onTrue = op.getInput2(); DenseElementsAttr onFalseAttr; DenseElementsAttr onTrueAttr; @@ -1722,18 +1722,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { } OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { - if (getOnTrue() == getOnFalse()) - return getOnTrue(); + if (getInput2() == getInput3()) + return getInput2(); auto predicate = - llvm::dyn_cast_if_present(adaptor.getPred()); + llvm::dyn_cast_if_present(adaptor.getInput1()); if (!predicate) return {}; if (!predicate.isSplat()) return {}; - return predicate.getSplatValue().getBoolValue() ? getOnTrue() - : getOnFalse(); + return predicate.getSplatValue().getBoolValue() ? getInput2() + : getInput3(); } OpFoldResult TileOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index 79afc75fd6c8e..87b2a2695351b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -169,9 +169,9 @@ struct ConvertTosaOp : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::SelectOp tosaOp, PatternRewriter &rewriter) const override { - Value input1 = tosaOp.getPred(); - Value input2 = tosaOp.getOnTrue(); - Value input3 = tosaOp.getOnFalse(); + Value input1 = tosaOp.getInput1(); + Value input2 = tosaOp.getInput2(); + Value input3 = tosaOp.getInput3(); Value output = tosaOp.getResult(); auto outputType = dyn_cast(output.getType());