Skip to content

Commit 50d0012

Browse files
committed
[AutoBump] Merge with fixes of d57479c (Feb 19)
2 parents c14ec74 + d57479c commit 50d0012

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
11961196
}];
11971197

11981198
let arguments = (ins
1199-
Tosa_I1Tensor:$pred,
1200-
Tosa_Tensor:$on_true,
1201-
Tosa_Tensor:$on_false
1199+
Tosa_I1Tensor:$input1,
1200+
Tosa_Tensor:$input2,
1201+
Tosa_Tensor:$input3
12021202
);
12031203

12041204
let results = (outs
@@ -1208,7 +1208,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
12081208
let hasFolder = 1;
12091209

12101210
let assemblyFormat = [{
1211-
operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
1211+
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
12121212
`)` `->` type($output)
12131213
}];
12141214
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ struct SelectLogicalNotOptimization : public OpRewritePattern<tosa::SelectOp> {
221221
using OpRewritePattern::OpRewritePattern;
222222
LogicalResult matchAndRewrite(tosa::SelectOp op,
223223
PatternRewriter &rewriter) const override {
224-
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
224+
auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
225225
if (!notOp)
226226
return failure();
227227
rewriter.modifyOpInPlace(op, [&]() {
228228
op.getOperation()->setOperands(
229-
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
229+
{notOp.getInput1(), op.getInput3(), op.getInput2()});
230230
});
231231
return success();
232232
}
@@ -247,7 +247,7 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
247247
LogicalResult matchAndRewrite(tosa::SelectOp op,
248248
PatternRewriter &rewriter) const override {
249249

250-
auto geq = op.getPred().getDefiningOp<tosa::GreaterEqualOp>();
250+
auto geq = op.getInput1().getDefiningOp<tosa::GreaterEqualOp>();
251251
if (!geq) {
252252
return rewriter.notifyMatchFailure(op,
253253
"Predicate is not a GreaterEqualOp");
@@ -297,8 +297,8 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
297297
return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();
298298
};
299299

300-
auto onFalse = op.getOnFalse();
301-
auto onTrue = op.getOnTrue();
300+
auto onFalse = op.getInput3();
301+
auto onTrue = op.getInput2();
302302
DenseElementsAttr onFalseAttr;
303303
DenseElementsAttr onTrueAttr;
304304

@@ -1722,18 +1722,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
17221722
}
17231723

17241724
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1725-
if (getOnTrue() == getOnFalse())
1726-
return getOnTrue();
1725+
if (getInput2() == getInput3())
1726+
return getInput2();
17271727

17281728
auto predicate =
1729-
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
1729+
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
17301730
if (!predicate)
17311731
return {};
17321732

17331733
if (!predicate.isSplat())
17341734
return {};
1735-
return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1736-
: getOnFalse();
1735+
return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1736+
: getInput3();
17371737
}
17381738

17391739
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
169169
LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170170
PatternRewriter &rewriter) const override {
171171

172-
Value input1 = tosaOp.getPred();
173-
Value input2 = tosaOp.getOnTrue();
174-
Value input3 = tosaOp.getOnFalse();
172+
Value input1 = tosaOp.getInput1();
173+
Value input2 = tosaOp.getInput2();
174+
Value input3 = tosaOp.getInput3();
175175
Value output = tosaOp.getResult();
176176

177177
auto outputType = dyn_cast<RankedTensorType>(output.getType());

0 commit comments

Comments
 (0)