From cd98fecd87724b20d9ef7e6437e7856aaaa2eeac Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 16 Aug 2025 08:26:08 +0000 Subject: [PATCH 1/2] [mlir][LLVM] `ControlFlowToLLVM`: Add 1:N type conversion support --- .../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 41 ++++++++++++++----- .../MemRefToLLVM/type-conversion.mlir | 17 ++++++++ mlir/test/lib/Dialect/LLVM/TestPatterns.cpp | 2 + 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index ff6d369176393..0b8a3af471516 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -125,22 +125,33 @@ static FailureOr getConvertedBlock(ConversionPatternRewriter &rewriter, return rewriter.applySignatureConversion(block, *conversion, converter); } +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + /// Convert the destination block signature (if necessary) and lower the branch /// op to llvm.br. struct BranchOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, + matchAndRewrite(cf::BranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector flattenedAdaptor = flattenValues(adaptor.getOperands()); FailureOr convertedBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), - TypeRange(adaptor.getOperands())); + TypeRange(flattenedAdaptor)); if (failed(convertedBlock)) return failure(); DictionaryAttr attrs = op->getAttrDictionary(); Operation *newOp = rewriter.replaceOpWithNewOp( - op, adaptor.getOperands(), *convertedBlock); + op, flattenedAdaptor, *convertedBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. newOp->setAttrs(attrs); @@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern { /// branch op to llvm.cond_br. struct CondBranchOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::CondBranchOp op, - typename cf::CondBranchOp::Adaptor adaptor, + matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector flattenedAdaptorTrue = + flattenValues(adaptor.getTrueDestOperands()); + SmallVector flattenedAdaptorFalse = + flattenValues(adaptor.getFalseDestOperands()); + if (!llvm::hasSingleElement(adaptor.getCondition())) + return rewriter.notifyMatchFailure(op, + "expected single element condition"); FailureOr convertedTrueBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), - TypeRange(adaptor.getTrueDestOperands())); + TypeRange(flattenedAdaptorTrue)); if (failed(convertedTrueBlock)) return failure(); FailureOr convertedFalseBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), - TypeRange(adaptor.getFalseDestOperands())); + TypeRange(flattenedAdaptorFalse)); if (failed(convertedFalseBlock)) return failure(); - DictionaryAttr attrs = op->getAttrDictionary(); + DictionaryAttr attrs = op->getDiscardableAttrDictionary(); auto newOp = rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), adaptor.getTrueDestOperands(), - adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(), + op, llvm::getSingleElement(adaptor.getCondition()), + flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(), *convertedTrueBlock, *convertedFalseBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. - newOp->setAttrs(attrs); + newOp->setDiscardableAttrs(attrs); return success(); } }; diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir index c1751f282b002..6c6756f5097b4 100644 --- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir @@ -138,3 +138,20 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) { %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1) return %res#0, %res#1, %res#0 : i17, i1, i17 } + +// ----- + +// CHECK-LABEL: llvm.func @branch( +// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18) +// CHECK: llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1) +// CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1): +// CHECK: llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18) +// CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18): +// CHECK: llvm.return +func.func @branch(%arg0: i1, %arg1: i17) { + cf.br ^bb1(%arg1, %arg0: i17, i1) +^bb1(%arg2: i17, %arg3: i1): + cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17) +^bb2(%arg4: i17): + return +} diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp index fe9aa0f2a9902..9d30ae43cccc1 100644 --- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass mlir::RewritePatternSet patterns(ctx); patterns.add(ctx, converter); populateFuncToLLVMConversionPatterns(converter, patterns); + cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); // Define the conversion target used for the test. ConversionTarget target(*ctx); From 70cd7279311c7cd4f7f4885bfd952b8318b99598 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 16 Aug 2025 12:34:25 +0200 Subject: [PATCH 2/2] Update mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp Co-authored-by: Tobias Gysi --- mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 0b8a3af471516..e1bbeb996d730 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -128,7 +128,7 @@ static FailureOr getConvertedBlock(ConversionPatternRewriter &rewriter, /// Flatten the given value ranges into a single vector of values. static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const ValueRange &vals : values) llvm::append_range(result, vals); return result; }