Skip to content

Commit f8f23e8

Browse files
[mlir][LLVM] ControlFlowToLLVM: Add 1:N type conversion support (#153937)
Add support for 1:N type conversions to the `ControlFlowToLLVM` lowering patterns. Not applicable to `cf.switch` and `cf.assert`. --------- Co-authored-by: Tobias Gysi <[email protected]>
1 parent f0967fc commit f8f23e8

File tree

3 files changed

+49
-11
lines changed

3 files changed

+49
-11
lines changed

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
125125
return rewriter.applySignatureConversion(block, *conversion, converter);
126126
}
127127

128+
/// Flatten the given value ranges into a single vector of values.
129+
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
130+
SmallVector<Value> result;
131+
for (const ValueRange &vals : values)
132+
llvm::append_range(result, vals);
133+
return result;
134+
}
135+
128136
/// Convert the destination block signature (if necessary) and lower the branch
129137
/// op to llvm.br.
130138
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
131139
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
140+
using Adaptor =
141+
typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
132142

133143
LogicalResult
134-
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
144+
matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
135145
ConversionPatternRewriter &rewriter) const override {
146+
SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
136147
FailureOr<Block *> convertedBlock =
137148
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
138-
TypeRange(adaptor.getOperands()));
149+
TypeRange(flattenedAdaptor));
139150
if (failed(convertedBlock))
140151
return failure();
141152
DictionaryAttr attrs = op->getAttrDictionary();
142153
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
143-
op, adaptor.getOperands(), *convertedBlock);
154+
op, flattenedAdaptor, *convertedBlock);
144155
// TODO: We should not just forward all attributes like that. But there are
145156
// existing Flang tests that depend on this behavior.
146157
newOp->setAttrs(attrs);
@@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
152163
/// branch op to llvm.cond_br.
153164
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
154165
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
166+
using Adaptor =
167+
typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
155168

156169
LogicalResult
157-
matchAndRewrite(cf::CondBranchOp op,
158-
typename cf::CondBranchOp::Adaptor adaptor,
170+
matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
159171
ConversionPatternRewriter &rewriter) const override {
172+
SmallVector<Value> flattenedAdaptorTrue =
173+
flattenValues(adaptor.getTrueDestOperands());
174+
SmallVector<Value> flattenedAdaptorFalse =
175+
flattenValues(adaptor.getFalseDestOperands());
176+
if (!llvm::hasSingleElement(adaptor.getCondition()))
177+
return rewriter.notifyMatchFailure(op,
178+
"expected single element condition");
160179
FailureOr<Block *> convertedTrueBlock =
161180
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
162-
TypeRange(adaptor.getTrueDestOperands()));
181+
TypeRange(flattenedAdaptorTrue));
163182
if (failed(convertedTrueBlock))
164183
return failure();
165184
FailureOr<Block *> convertedFalseBlock =
166185
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
167-
TypeRange(adaptor.getFalseDestOperands()));
186+
TypeRange(flattenedAdaptorFalse));
168187
if (failed(convertedFalseBlock))
169188
return failure();
170-
DictionaryAttr attrs = op->getAttrDictionary();
189+
DictionaryAttr attrs = op->getDiscardableAttrDictionary();
171190
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
172-
op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
173-
adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
191+
op, llvm::getSingleElement(adaptor.getCondition()),
192+
flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
174193
*convertedTrueBlock, *convertedFalseBlock);
175194
// TODO: We should not just forward all attributes like that. But there are
176195
// existing Flang tests that depend on this behavior.
177-
newOp->setAttrs(attrs);
196+
newOp->setDiscardableAttrs(attrs);
178197
return success();
179198
}
180199
};

mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,20 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
138138
%res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
139139
return %res#0, %res#1, %res#0 : i17, i1, i17
140140
}
141+
142+
// -----
143+
144+
// CHECK-LABEL: llvm.func @branch(
145+
// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
146+
// CHECK: llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1)
147+
// CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1):
148+
// CHECK: llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18)
149+
// CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18):
150+
// CHECK: llvm.return
151+
func.func @branch(%arg0: i1, %arg1: i17) {
152+
cf.br ^bb1(%arg1, %arg0: i17, i1)
153+
^bb1(%arg2: i17, %arg3: i1):
154+
cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17)
155+
^bb2(%arg4: i17):
156+
return
157+
}

mlir/test/lib/Dialect/LLVM/TestPatterns.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
910
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
1011
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
7071
mlir::RewritePatternSet patterns(ctx);
7172
patterns.add<TestDirectReplacementOp>(ctx, converter);
7273
populateFuncToLLVMConversionPatterns(converter, patterns);
74+
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
7375

7476
// Define the conversion target used for the test.
7577
ConversionTarget target(*ctx);

0 commit comments

Comments
 (0)