Skip to content

Commit f7b09ad

Browse files
[mlir][LLVM] ArithToLLVM: Add 1:N support for arith.select lowering (#153944)
Add 1:N support for the `arith.select` lowering. Only cases where the entire true/false value is selected are supported.
1 parent 127ba53 commit f7b09ad

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
238238
ConversionPatternRewriter &rewriter) const override;
239239
};
240240

241+
struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
242+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
243+
using Adaptor =
244+
typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
245+
246+
LogicalResult
247+
matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
248+
ConversionPatternRewriter &rewriter) const override;
249+
};
250+
241251
} // namespace
242252

243253
//===----------------------------------------------------------------------===//
@@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
479489
rewriter);
480490
}
481491

492+
//===----------------------------------------------------------------------===//
493+
// SelectOpOneToNLowering
494+
//===----------------------------------------------------------------------===//
495+
496+
/// Pattern for arith.select where the true/false values lower to multiple
497+
/// SSA values (1:N conversion). This pattern generates multiple arith.select
498+
/// than can be lowered by the 1:1 arith.select pattern.
499+
LogicalResult SelectOpOneToNLowering::matchAndRewrite(
500+
arith::SelectOp op, Adaptor adaptor,
501+
ConversionPatternRewriter &rewriter) const {
502+
// In case of a 1:1 conversion, the 1:1 pattern will match.
503+
if (llvm::hasSingleElement(adaptor.getTrueValue()))
504+
return rewriter.notifyMatchFailure(
505+
op, "not a 1:N conversion, 1:1 pattern will match");
506+
if (!op.getCondition().getType().isInteger(1))
507+
return rewriter.notifyMatchFailure(op,
508+
"non-i1 conditions are not supported");
509+
SmallVector<Value> results;
510+
for (auto [trueValue, falseValue] :
511+
llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
512+
results.push_back(arith::SelectOp::create(
513+
rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
514+
rewriter.replaceOpWithMultiple(op, {results});
515+
return success();
516+
}
517+
482518
//===----------------------------------------------------------------------===//
483519
// Pass Definition
484520
//===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
587623
RemSIOpLowering,
588624
RemUIOpLowering,
589625
SelectOpLowering,
626+
SelectOpOneToNLowering,
590627
ShLIOpLowering,
591628
ShRSIOpLowering,
592629
ShRUIOpLowering,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
3+
4+
// CHECK-LABEL: llvm.func @arith_select(
5+
// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
6+
// CHECK: %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
7+
// CHECK: %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
8+
// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
9+
// CHECK: %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
10+
// CHECK: %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
11+
// CHECK: llvm.return %[[i2]]
12+
func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
13+
%0 = arith.select %arg0, %arg1, %arg2 : i17
14+
return %0 : i17
15+
}

mlir/test/lib/Dialect/LLVM/TestPatterns.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
910
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1011
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
1112
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
7071
// Populate patterns.
7172
mlir::RewritePatternSet patterns(ctx);
7273
patterns.add<TestDirectReplacementOp>(ctx, converter);
74+
arith::populateArithToLLVMConversionPatterns(converter, patterns);
7375
populateFuncToLLVMConversionPatterns(converter, patterns);
7476
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
7577

0 commit comments

Comments
 (0)