diff --git a/clang/include/clang/CIR/Passes.h b/clang/include/clang/CIR/Passes.h index 3f8a174aac0c..01227e376de9 100644 --- a/clang/include/clang/CIR/Passes.h +++ b/clang/include/clang/CIR/Passes.h @@ -18,6 +18,11 @@ #include namespace cir { +/// Create a pass for transforming CIR operations to more 'scf' dialect-friendly +/// forms. It rewrites operations that aren't supported by 'scf', such as breaks +/// and continues. +std::unique_ptr createMLIRLoweringPreparePass(); + /// Create a pass for lowering from MLIR builtin dialects such as `Affine` and /// `Std`, to the LLVM dialect for codegen. std::unique_ptr createConvertMLIRToLLVMPass(); diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt b/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt index 8c2631ab57d8..3fdb514aa47a 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt +++ b/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt @@ -9,6 +9,7 @@ add_clang_library(clangCIRLoweringThroughMLIR LowerCIRLoopToSCF.cpp LowerCIRToMLIR.cpp LowerMLIRToLLVM.cpp + MLIRLoweringPrepare.cpp DEPENDS MLIRCIROpsIncGen diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 5d2b4180571a..075a13464b7e 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -48,19 +48,17 @@ #include "mlir/Transforms/DialectConversion.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "clang/CIR/Interfaces/CIRLoopOpInterface.h" #include "clang/CIR/LowerToLLVM.h" #include "clang/CIR/LowerToMLIR.h" #include "clang/CIR/LoweringHelpers.h" #include "clang/CIR/Passes.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/ErrorHandling.h" -#include "clang/CIR/Interfaces/CIRLoopOpInterface.h" -#include "clang/CIR/LowerToLLVM.h" -#include "clang/CIR/Passes.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Value.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/TimeProfiler.h" using namespace cir; @@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern { } else { // For scopes with results, use scf.execute_region SmallVector types; - if (mlir::failed( - getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types))) + if (mlir::failed(getTypeConverter()->convertTypes( + scopeOp->getResultTypes(), types))) return mlir::failure(); auto exec = rewriter.create(scopeOp.getLoc(), types); @@ -1515,28 +1513,117 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern { } }; +class CIRSwitchOpLowering : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.setInsertionPointAfter(op); + llvm::SmallVector cases; + if (!op.isSimpleForm(cases)) + llvm_unreachable("NYI"); + + llvm::SmallVector caseValues; + // Maps the index of a CaseOp in `cases`, to the index in `caseValues`. + // This is necessary because some CaseOp might carry 0 or multiple values. + llvm::DenseMap indexMap; + caseValues.reserve(cases.size()); + for (auto [i, caseOp] : llvm::enumerate(cases)) { + switch (caseOp.getKind()) { + case CaseOpKind::Equal: { + auto valueAttr = caseOp.getValue()[0]; + auto value = cast(valueAttr); + indexMap[i] = caseValues.size(); + caseValues.push_back(value.getUInt()); + break; + } + case CaseOpKind::Default: + break; + case CaseOpKind::Range: + case CaseOpKind::Anyof: + llvm_unreachable("NYI"); + } + } + + auto operand = adaptor.getOperands()[0]; + // `scf.index_switch` expects an index of type `index`. + auto indexType = mlir::IndexType::get(getContext()); + auto indexCast = rewriter.create( + op.getLoc(), indexType, operand); + auto indexSwitch = rewriter.create( + op.getLoc(), mlir::TypeRange{}, indexCast, caseValues, cases.size()); + + bool metDefault = false; + for (auto [i, caseOp] : llvm::enumerate(cases)) { + auto ®ion = caseOp.getRegion(); + switch (caseOp.getKind()) { + case CaseOpKind::Equal: { + auto &caseRegion = indexSwitch.getCaseRegions()[indexMap[i]]; + rewriter.inlineRegionBefore(region, caseRegion, caseRegion.end()); + break; + } + case CaseOpKind::Default: { + auto &defaultRegion = indexSwitch.getDefaultRegion(); + rewriter.inlineRegionBefore(region, defaultRegion, defaultRegion.end()); + metDefault = true; + break; + } + case CaseOpKind::Range: + case CaseOpKind::Anyof: + llvm_unreachable("NYI"); + } + } + + // `scf.index_switch` expects its default region to contain exactly one + // block. If we don't have a default region in `cir.switch`, we need to + // supply it here. + if (!metDefault) { + auto &defaultRegion = indexSwitch.getDefaultRegion(); + mlir::Block *block = + rewriter.createBlock(&defaultRegion, defaultRegion.end()); + rewriter.setInsertionPointToEnd(block); + rewriter.create(op.getLoc()); + } + + // The final `cir.break` should be replaced to `scf.yield`. + // After MLIRLoweringPrepare pass, every case must end with a `cir.break`. + for (auto ®ion : indexSwitch.getCaseRegions()) { + auto &lastBlock = region.back(); + auto &lastOp = lastBlock.back(); + assert(isa(lastOp)); + rewriter.setInsertionPointAfter(&lastOp); + rewriter.replaceOpWithNewOp(&lastOp); + } + + rewriter.replaceOp(op, indexSwitch); + + return mlir::success(); + } +}; + void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); - patterns - .add(converter, patterns.getContext()); + patterns.add< + CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering, + CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering, + CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering, + CIRAllocaOpLowering, CIRFuncOpLowering, CIRBrCondOpLowering, + CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering, + CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering, + CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering, + CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, + CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering, + CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering, + CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering, + CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering, + CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering, + CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering, + CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering, + CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { @@ -1610,7 +1697,7 @@ void ConvertCIRToMLIRPass::runOnOperation() { mlir::ModuleOp theModule = getOperation(); auto converter = prepareTypeConverter(); - + mlir::RewritePatternSet patterns(&getContext()); populateCIRLoopToSCFConversionPatterns(patterns, converter); @@ -1628,10 +1715,11 @@ void ConvertCIRToMLIRPass::runOnOperation() { // cir dialect, for example the `cir.continue`. If we marked cir as illegal // here, then MLIR would think any remaining `cir.continue` indicates a // failure, which is not what we want. - - patterns.add(converter, context); - if (mlir::failed(mlir::applyPartialConversion(theModule, target, + patterns.add(converter, context); + + if (mlir::failed(mlir::applyPartialConversion(theModule, target, std::move(patterns)))) { signalPassFailure(); } @@ -1646,6 +1734,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule, mlir::PassManager pm(mlirCtx); + pm.addPass(createMLIRLoweringPreparePass()); pm.addPass(createConvertCIRToMLIRPass()); pm.addPass(createConvertMLIRToLLVMPass()); @@ -1712,6 +1801,8 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule, llvm::TimeTraceScope scope("Lower CIR To MLIR"); mlir::PassManager pm(mlirCtx); + + pm.addPass(createMLIRLoweringPreparePass()); pm.addPass(createConvertCIRToMLIRPass()); auto result = !mlir::failed(pm.run(theModule)); diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/MLIRLoweringPrepare.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/MLIRLoweringPrepare.cpp new file mode 100644 index 000000000000..d450de8e7b07 --- /dev/null +++ b/clang/lib/CIR/Lowering/ThroughMLIR/MLIRLoweringPrepare.cpp @@ -0,0 +1,112 @@ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" + +using namespace llvm; +using namespace cir; + +namespace cir { + +struct MLIRLoweringPrepare + : public mlir::PassWrapper> { + // `scf.index_switch` requires that switch branches do not fall through. + // We need to copy the next branch's body when the current `cir.case` does not + // terminate with a break. + void removeFallthrough(llvm::SmallVector &cases); + + void runOnOp(mlir::Operation *op); + void runOnOperation() final; + + StringRef getDescription() const override { + return "Rewrite CIR module to be more 'scf' dialect-friendly"; + } + + StringRef getArgument() const override { return "mlir-lowering-prepare"; } +}; + +// `scf.index_switch` requires that switch branches do not fall through. +// We need to copy the next branch's body when the current `cir.case` does not +// terminate with a break. +void MLIRLoweringPrepare::removeFallthrough(llvm::SmallVector &cases) { + CIRBaseBuilderTy builder(getContext()); + // Note we enumerate in the reverse order, to facilitate the cloning. + for (auto it = cases.rbegin(); it != cases.rend(); it++) { + auto caseOp = *it; + auto ®ion = caseOp.getRegion(); + auto &lastBlock = region.back(); + mlir::Operation &last = lastBlock.back(); + if (isa(last)) + continue; + + // The last op must be a `cir.yield`. As it falls through, we copy the + // previous case's body to this one. + if (!isa(last)) { + caseOp->dump(); + continue; + } + assert(isa(last)); + + // If there's no previous case, we can simply change the yield into a break. + if (it == cases.rbegin()) { + builder.setInsertionPointAfter(&last); + builder.create(last.getLoc()); + last.erase(); + continue; + } + + auto prevIt = it; + --prevIt; + CaseOp &prev = *prevIt; + auto &prevRegion = prev.getRegion(); + mlir::IRMapping mapping; + builder.cloneRegionBefore(prevRegion, region, region.end()); + + // We inline the block to the end. + // This is required because `scf.index_switch` expects that each of its + // region contains a single block. + mlir::Block *cloned = lastBlock.getNextNode(); + for (auto it = cloned->begin(); it != cloned->end();) { + auto next = it; + next++; + it->moveBefore(&last); + it = next; + } + cloned->erase(); + last.erase(); + } +} + +void MLIRLoweringPrepare::runOnOp(mlir::Operation *op) { + if (auto switchOp = dyn_cast(op)) { + llvm::SmallVector cases; + if (!switchOp.isSimpleForm(cases)) + llvm_unreachable("NYI"); + + removeFallthrough(cases); + return; + } + llvm_unreachable("unexpected op type"); +} + +void MLIRLoweringPrepare::runOnOperation() { + auto module = getOperation(); + + llvm::SmallVector opsToTransform; + module->walk([&](mlir::Operation *op) { + if (isa(op)) + opsToTransform.push_back(op); + }); + + for (auto *op : opsToTransform) + runOnOp(op); +} + +std::unique_ptr createMLIRLoweringPreparePass() { + return std::make_unique(); +} + +} // namespace cir \ No newline at end of file diff --git a/clang/test/CIR/Lowering/ThroughMLIR/switch.c b/clang/test/CIR/Lowering/ThroughMLIR/switch.c new file mode 100644 index 000000000000..02756b3a2536 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/switch.c @@ -0,0 +1,50 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +void fallthrough() { + int i = 0; + switch (i) { + case 2: + i++; + case 3: + i++; + break; + case 8: + i++; + } + + // This should copy the `i++; break` in case 3 to case 2. + + // CHECK: memref.alloca_scope { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[CASTED:.+]] = arith.index_cast %[[I]] + // CHECK: scf.index_switch %[[CASTED]] + // CHECK: case 2 { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: scf.yield + // CHECK: } + // CHECK: case 3 { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: scf.yield + // CHECK: } + // CHECK: case 8 { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: scf.yield + // CHECK: } + // CHECK: default { + // CHECK: } + // CHECK: } +} \ No newline at end of file