-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][LLVM] ControlFlowToLLVM: Add 1:N type conversion support
#153937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM] ControlFlowToLLVM: Add 1:N type conversion support
#153937
Conversation
|
@llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesAdd support for 1:N type conversions to the Full diff: https://github.com/llvm/llvm-project/pull/153937.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369176393..fa0023d6a0621 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(flattenedAdaptor));
if (failed(convertedBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedAdaptor, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(attrs);
@@ -152,29 +163,42 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(flattenedAdaptorTrue));
if (failed(convertedTrueBlock))
return failure();
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(flattenedAdaptorFalse));
if (failed(convertedFalseBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
- adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
*convertedTrueBlock, *convertedFalseBlock);
// TODO: We should not just forward all attributes like that. But there are
- // existing Flang tests that depend on this behavior.
- newOp->setAttrs(attrs);
+ // existing Flang tests that depend on this behavior. E.g., it is incorrect
+ // to forward the `operandSegmentSizes` attribute. We cannot hard-code all
+ // attributes that must be excluded from forwarding.
+ for (NamedAttribute attr : attrs) {
+ if (attr.getName() != cf::CondBranchOp::getOperandSegmentSizeAttr())
+ newOp->setAttr(attr.getName(), attr.getValue());
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..6c6756f5097b4 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,20 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
%res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
return %res#0, %res#1, %res#0 : i17, i1, i17
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @branch(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK: llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1)
+// CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1):
+// CHECK: llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18)
+// CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18):
+// CHECK: llvm.return
+func.func @branch(%arg0: i1, %arg1: i17) {
+ cf.br ^bb1(%arg1, %arg0: i17, i1)
+^bb1(%arg2: i17, %arg3: i1):
+ cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17)
+^bb2(%arg4: i17):
+ return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..9d30ae43cccc1 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
populateFuncToLLVMConversionPatterns(converter, patterns);
+ cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// Define the conversion target used for the test.
ConversionTarget target(*ctx);
|
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd support for 1:N type conversions to the Full diff: https://github.com/llvm/llvm-project/pull/153937.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369176393..fa0023d6a0621 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(flattenedAdaptor));
if (failed(convertedBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedAdaptor, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(attrs);
@@ -152,29 +163,42 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(flattenedAdaptorTrue));
if (failed(convertedTrueBlock))
return failure();
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(flattenedAdaptorFalse));
if (failed(convertedFalseBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
- adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
*convertedTrueBlock, *convertedFalseBlock);
// TODO: We should not just forward all attributes like that. But there are
- // existing Flang tests that depend on this behavior.
- newOp->setAttrs(attrs);
+ // existing Flang tests that depend on this behavior. E.g., it is incorrect
+ // to forward the `operandSegmentSizes` attribute. We cannot hard-code all
+ // attributes that must be excluded from forwarding.
+ for (NamedAttribute attr : attrs) {
+ if (attr.getName() != cf::CondBranchOp::getOperandSegmentSizeAttr())
+ newOp->setAttr(attr.getName(), attr.getValue());
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..6c6756f5097b4 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,20 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
%res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
return %res#0, %res#1, %res#0 : i17, i1, i17
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @branch(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK: llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1)
+// CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1):
+// CHECK: llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18)
+// CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18):
+// CHECK: llvm.return
+func.func @branch(%arg0: i1, %arg1: i17) {
+ cf.br ^bb1(%arg1, %arg0: i17, i1)
+^bb1(%arg2: i17, %arg3: i1):
+ cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17)
+^bb2(%arg4: i17):
+ return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..9d30ae43cccc1 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
populateFuncToLLVMConversionPatterns(converter, patterns);
+ cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// Define the conversion target used for the test.
ConversionTarget target(*ctx);
|
df75085 to
cd98fec
Compare
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Co-authored-by: Tobias Gysi <[email protected]>
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/28958 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/17026 Here is the relevant piece of the build log for the reference |
Fix build after #153937.
Fix build after #153937.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/4/builds/8480 Here is the relevant piece of the build log for the reference |
Fix build after llvm#153937.
Add support for 1:N type conversions to the
ControlFlowToLLVMlowering patterns. Not applicable tocf.switchandcf.assert.