Skip to content

Commit b93a2e6

Browse files
committed
[CIR] Make CIR-to-LLVM a one shot conversion
This had to fix memory and conversion bugs due to now immediate conversion patterns and no longer present original MLIR.
1 parent 68ca1b8 commit b93a2e6

File tree

2 files changed

+86
-43
lines changed

2 files changed

+86
-43
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2160,7 +2160,7 @@ def CIR_GlobalOp : CIR_Op<"global", [
21602160
cir::GlobalOp op, mlir::Attribute init,
21612161
mlir::ConversionPatternRewriter &rewriter) const;
21622162

2163-
void setupRegionInitializedLLVMGlobalOp(
2163+
mlir::LLVM::GlobalOp setupRegionInitializedLLVMGlobalOp(
21642164
cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const;
21652165

21662166
mutable mlir::LLVM::ComdatOp comdatOp = nullptr;

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,7 @@ mlir::LogicalResult CIRToLLVMIsFPClassOpLowering::matchAndRewrite(
716716
mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite(
717717
cir::AssumeOp op, OpAdaptor adaptor,
718718
mlir::ConversionPatternRewriter &rewriter) const {
719-
auto cond = adaptor.getPredicate();
720-
rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, cond);
719+
rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, adaptor.getPredicate());
721720
return mlir::success();
722721
}
723722

@@ -1130,11 +1129,11 @@ mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite(
11301129
// ZExtOp and if so, delete it if it has a single use.
11311130
assert(!cir::MissingFeatures::zextOp());
11321131

1133-
mlir::Value i1Condition = adaptor.getCond();
1134-
1132+
11351133
rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
1136-
brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(),
1137-
brOp.getDestFalse(), adaptor.getDestOperandsFalse());
1134+
brOp, adaptor.getCond(), brOp.getDestTrue(),
1135+
adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
1136+
adaptor.getDestOperandsFalse());
11381137

11391138
return mlir::success();
11401139
}
@@ -1942,12 +1941,12 @@ mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewriteAlias(
19421941
lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
19431942

19441943
mlir::Location loc = op.getLoc();
1944+
mlir::OpBuilder builder(op.getContext());
19451945
auto aliasOp = rewriter.replaceOpWithNewOp<mlir::LLVM::AliasOp>(
19461946
op, ty, convertLinkage(op.getLinkage()), op.getName(), op.getDsoLocal(),
19471947
/*threadLocal=*/false, attributes);
19481948

19491949
// Create the alias body
1950-
mlir::OpBuilder builder(op.getContext());
19511950
mlir::Block *block = builder.createBlock(&aliasOp.getInitializerRegion());
19521951
builder.setInsertionPointToStart(block);
19531952
// The type of AddressOfOp is always a pointer.
@@ -2053,7 +2052,8 @@ mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite(
20532052

20542053
/// Replace CIR global with a region initialized LLVM global and update
20552054
/// insertion point to the end of the initializer block.
2056-
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
2055+
mlir::LLVM::GlobalOp
2056+
CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
20572057
cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
20582058
const mlir::Type llvmType =
20592059
convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType());
@@ -2080,6 +2080,7 @@ void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
20802080
isDsoLocal, isThreadLocal, comdatAttr, attributes);
20812081
newGlobalOp.getRegion().emplaceBlock();
20822082
rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());
2083+
return newGlobalOp;
20832084
}
20842085

20852086
mlir::LogicalResult
@@ -2097,8 +2098,9 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
20972098
// should be updated. For now, we use a custom op to initialize globals
20982099
// to the appropriate value.
20992100
const mlir::Location loc = op.getLoc();
2100-
setupRegionInitializedLLVMGlobalOp(op, rewriter);
2101-
CIRAttrToValue valueConverter(op, rewriter, typeConverter);
2101+
mlir::LLVM::GlobalOp newGlobalOp =
2102+
setupRegionInitializedLLVMGlobalOp(op, rewriter);
2103+
CIRAttrToValue valueConverter(newGlobalOp, rewriter, typeConverter);
21022104
mlir::Value value = valueConverter.visit(init);
21032105
mlir::LLVM::ReturnOp::create(rewriter, loc, value);
21042106
return mlir::success();
@@ -2795,42 +2797,45 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
27952797
mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
27962798
cir::SelectOp op, OpAdaptor adaptor,
27972799
mlir::ConversionPatternRewriter &rewriter) const {
2798-
auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
2799-
auto definingOp = value.getDefiningOp<cir::ConstantOp>();
2800-
if (!definingOp)
2801-
return {};
28022800

2803-
auto constValue = definingOp.getValueAttr<cir::BoolAttr>();
2804-
if (!constValue)
2805-
return {};
2801+
// Helper to extract boolean constant value
2802+
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
2803+
auto constOp = value.getDefiningOp<mlir::LLVM::ConstantOp>();
2804+
if (!constOp)
2805+
return std::nullopt;
28062806

2807-
return constValue;
2807+
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(constOp.getValue());
2808+
if (!intAttr)
2809+
return std::nullopt;
2810+
2811+
return !intAttr.getValue().isZero();
28082812
};
28092813

2814+
mlir::Value condition = adaptor.getCondition();
2815+
mlir::Value trueValue = adaptor.getTrueValue();
2816+
mlir::Value falseValue = adaptor.getFalseValue();
2817+
28102818
// Two special cases in the LLVMIR codegen of select op:
2811-
// - select %0, %1, false => and %0, %1
2812-
// - select %0, true, %1 => or %0, %1
2819+
// - select %cond, %val, false => and %cond, %val
2820+
// - select %cond, true, %val => or %cond, %val
28132821
if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
2814-
cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
2815-
cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
2816-
if (falseValue && !falseValue.getValue()) {
2817-
// select %0, %1, false => and %0, %1
2818-
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
2819-
adaptor.getTrueValue());
2822+
// Optimization: select %cond, %val, false => and %cond, %val
2823+
std::optional<bool> falseConst = getConstantBool(falseValue);
2824+
if (falseConst && !*falseConst) {
2825+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, condition, trueValue);
28202826
return mlir::success();
28212827
}
2822-
if (trueValue && trueValue.getValue()) {
2823-
// select %0, true, %1 => or %0, %1
2824-
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
2825-
adaptor.getFalseValue());
2828+
// Optimization: select %cond, true, %val => or %cond, %val
2829+
std::optional<bool> trueConst = getConstantBool(trueValue);
2830+
if (trueConst && *trueConst) {
2831+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, condition, falseValue);
28262832
return mlir::success();
28272833
}
28282834
}
28292835

2830-
mlir::Value llvmCondition = adaptor.getCondition();
2831-
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
2832-
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
2833-
2836+
// Default case: emit standard LLVM select
2837+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(op, condition, trueValue,
2838+
falseValue);
28342839
return mlir::success();
28352840
}
28362841

@@ -3008,13 +3013,12 @@ static void buildCtorDtorList(
30083013
mlir::LLVM::ReturnOp::create(builder, loc, result);
30093014
}
30103015

3011-
// The applyPartialConversion function traverses blocks in the dominance order,
3012-
// so it does not lower and operations that are not reachachable from the
3013-
// operations passed in as arguments. Since we do need to lower such code in
3014-
// order to avoid verification errors occur, we cannot just pass the module op
3015-
// to applyPartialConversion. We must build a set of unreachable ops and
3016-
// explicitly add them, along with the module, to the vector we pass to
3017-
// applyPartialConversion.
3016+
// The applyFullConversion function performs a full conversion that legalizes
3017+
// all operations. It traverses all operations including unreachable blocks, so
3018+
// we need to collect unreachable operations and explicitly add them along with
3019+
// the module to ensure they are converted. We use one-shot conversion mode
3020+
// (allowPatternRollback = false) for better performance by avoiding rollback
3021+
// state maintenance.
30183022
//
30193023
// For instance, this CIR code:
30203024
//
@@ -3135,7 +3139,10 @@ void ConvertCIRToLLVMPass::runOnOperation() {
31353139
ops.push_back(module);
31363140
collectUnreachable(module, ops);
31373141

3138-
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
3142+
mlir::ConversionConfig config;
3143+
config.allowPatternRollback = false;
3144+
3145+
if (failed(applyFullConversion(ops, target, std::move(patterns), config)))
31393146
signalPassFailure();
31403147

31413148
// Emit the llvm.global_ctors array.
@@ -3750,11 +3757,31 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
37503757
mlir::ConversionPatternRewriter &rewriter) const {
37513758
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
37523759
mlir::Value operand = adaptor.getOperand();
3760+
3761+
// FIXME:
3762+
// Check if we're extracting from a ComplexCreate that was already lowered
3763+
// Pattern: insertvalue(insertvalue(undef, real, 0), imag, 1) -> just use
3764+
// 'real'
3765+
if (auto secondInsert = operand.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
3766+
if (secondInsert.getPosition() == llvm::ArrayRef<int64_t>{1}) {
3767+
if (auto firstInsert = secondInsert.getContainer()
3768+
.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
3769+
if (firstInsert.getPosition() == llvm::ArrayRef<int64_t>{0}) {
3770+
// This is the pattern we're looking for - return the real component
3771+
// directly
3772+
rewriter.replaceOp(op, firstInsert.getValue());
3773+
return mlir::success();
3774+
}
3775+
}
3776+
}
3777+
}
3778+
37533779
if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
37543780
operand = mlir::LLVM::ExtractValueOp::create(
37553781
rewriter, op.getLoc(), resultLLVMTy, operand,
37563782
llvm::ArrayRef<std::int64_t>{0});
37573783
}
3784+
37583785
rewriter.replaceOp(op, operand);
37593786
return mlir::success();
37603787
}
@@ -3815,6 +3842,22 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
38153842
mlir::Value operand = adaptor.getOperand();
38163843
mlir::Location loc = op.getLoc();
38173844

3845+
// FIXME:
3846+
// Check if we're extracting from a ComplexCreate that was already lowered
3847+
// Pattern: insertvalue(insertvalue(undef, real, 0), imag, 1) -> just use
3848+
// 'imag'
3849+
if (auto secondInsert = operand.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
3850+
if (secondInsert.getPosition() == llvm::ArrayRef<int64_t>{1}) {
3851+
if (secondInsert.getContainer()
3852+
.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
3853+
// This is the pattern we're looking for - return the imag component
3854+
// directly
3855+
rewriter.replaceOp(op, secondInsert.getValue());
3856+
return mlir::success();
3857+
}
3858+
}
3859+
}
3860+
38183861
if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
38193862
operand = mlir::LLVM::ExtractValueOp::create(
38203863
rewriter, loc, resultLLVMTy, operand, llvm::ArrayRef<std::int64_t>{1});

0 commit comments

Comments
 (0)