Skip to content

Commit cba1518

Browse files
committed
Re-align lowering code with incubator implementation
1 parent bb41af6 commit cba1518

File tree

2 files changed

+59
-52
lines changed

2 files changed

+59
-52
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 59 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -44,48 +44,52 @@ class CIRAttrToValue {
4444
const mlir::TypeConverter *converter)
4545
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
4646

47-
mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
48-
4947
mlir::Value visit(mlir::Attribute attr) {
5048
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
5149
.Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
5250
[&](auto attrT) { return visitCirAttr(attrT); })
5351
.Default([&](auto attrT) { return mlir::Value(); });
5452
}
5553

56-
mlir::Value visitCirAttr(cir::IntAttr intAttr) {
57-
mlir::Location loc = parentOp->getLoc();
58-
return rewriter.create<mlir::LLVM::ConstantOp>(
59-
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
60-
}
61-
62-
mlir::Value visitCirAttr(cir::FPAttr fltAttr) {
63-
mlir::Location loc = parentOp->getLoc();
64-
return rewriter.create<mlir::LLVM::ConstantOp>(
65-
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
66-
}
67-
68-
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) {
69-
mlir::Location loc = parentOp->getLoc();
70-
if (ptrAttr.isNullValue()) {
71-
return rewriter.create<mlir::LLVM::ZeroOp>(
72-
loc, converter->convertType(ptrAttr.getType()));
73-
}
74-
mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
75-
mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
76-
loc,
77-
rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
78-
ptrAttr.getValue().getInt());
79-
return rewriter.create<mlir::LLVM::IntToPtrOp>(
80-
loc, converter->convertType(ptrAttr.getType()), ptrVal);
81-
}
54+
mlir::Value visitCirAttr(cir::IntAttr intAttr);
55+
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
56+
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
8257

8358
private:
8459
mlir::Operation *parentOp;
8560
mlir::ConversionPatternRewriter &rewriter;
8661
const mlir::TypeConverter *converter;
8762
};
8863

64+
/// IntAttr visitor.
65+
mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
66+
mlir::Location loc = parentOp->getLoc();
67+
return rewriter.create<mlir::LLVM::ConstantOp>(
68+
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
69+
}
70+
71+
/// ConstPtrAttr visitor.
72+
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
73+
mlir::Location loc = parentOp->getLoc();
74+
if (ptrAttr.isNullValue()) {
75+
return rewriter.create<mlir::LLVM::ZeroOp>(
76+
loc, converter->convertType(ptrAttr.getType()));
77+
}
78+
mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
79+
mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
80+
loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
81+
ptrAttr.getValue().getInt());
82+
return rewriter.create<mlir::LLVM::IntToPtrOp>(
83+
loc, converter->convertType(ptrAttr.getType()), ptrVal);
84+
}
85+
86+
/// FPAttr visitor.
87+
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
88+
mlir::Location loc = parentOp->getLoc();
89+
return rewriter.create<mlir::LLVM::ConstantOp>(
90+
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
91+
}
92+
8993
// This class handles rewriting initializer attributes for types that do not
9094
// require region initialization.
9195
class GlobalInitAttrRewriter {
@@ -94,8 +98,6 @@ class GlobalInitAttrRewriter {
9498
mlir::ConversionPatternRewriter &rewriter)
9599
: llvmType(type), rewriter(rewriter) {}
96100

97-
mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
98-
99101
mlir::Attribute visit(mlir::Attribute attr) {
100102
return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
101103
.Case<cir::IntAttr, cir::FPAttr>(
@@ -137,12 +139,6 @@ struct ConvertCIRToLLVMPass
137139
StringRef getArgument() const override { return "cir-flat-to-llvm"; }
138140
};
139141

140-
bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization(
141-
mlir::Attribute attr) const {
142-
// There will be more cases added later.
143-
return isa<cir::ConstPtrAttr>(attr);
144-
}
145-
146142
/// Replace CIR global with a region initialized LLVM global and update
147143
/// insertion point to the end of the initializer block.
148144
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
@@ -189,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
189185
// to the appropriate value.
190186
const mlir::Location loc = op.getLoc();
191187
setupRegionInitializedLLVMGlobalOp(op, rewriter);
192-
CIRAttrToValue attrVisitor(op, rewriter, typeConverter);
193-
mlir::Value value = attrVisitor.lowerCirAttrAsValue(init);
188+
CIRAttrToValue valueConverter(op, rewriter, typeConverter);
189+
mlir::Value value = valueConverter.visit(init);
194190
rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
195191
return mlir::success();
196192
}
@@ -201,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
201197

202198
std::optional<mlir::Attribute> init = op.getInitialValue();
203199

204-
// If we have an initializer and it requires region initialization, handle
205-
// that separately
206-
if (init.has_value() && attrRequiresRegionInitialization(init.value())) {
207-
return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
208-
}
209-
210200
// Fetch required values to create LLVM op.
211201
const mlir::Type cirSymType = op.getSymType();
212202

@@ -231,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
231221
SmallVector<mlir::NamedAttribute> attributes;
232222

233223
if (init.has_value()) {
234-
GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
235-
init = initRewriter.rewriteInitAttr(init.value());
236-
// If initRewriter returned a null attribute, init will have a value but
237-
// the value will be null. If that happens, initRewriter didn't handle the
238-
// attribute type. It probably needs to be added to GlobalInitAttrRewriter.
239-
if (!init.value()) {
224+
if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
225+
// If a directly equivalent attribute is available, use it.
226+
init =
227+
llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
228+
.Case<cir::FPAttr>([&](cir::FPAttr attr) {
229+
return rewriter.getFloatAttr(llvmType, attr.getValue());
230+
})
231+
.Case<cir::IntAttr>([&](cir::IntAttr attr) {
232+
return rewriter.getIntegerAttr(llvmType, attr.getValue());
233+
})
234+
.Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
235+
// If initRewriter returned a null attribute, init will have a value but
236+
// the value will be null.
237+
if (!init.value()) {
238+
op.emitError() << "unsupported initializer '" << init.value() << "'";
239+
return mlir::failure();
240+
}
241+
} else if (mlir::isa<cir::ConstPtrAttr>(init.value())) {
242+
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
243+
// should be updated. For now, we use a custom op to initialize globals
244+
// to the appropriate value.
245+
return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
246+
} else {
247+
// We will only get here if new initializer types are added and this
248+
// code is not updated to handle them.
240249
op.emitError() << "unsupported initializer '" << init.value() << "'";
241250
return mlir::failure();
242251
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering
3636
mlir::ConversionPatternRewriter &rewriter) const override;
3737

3838
private:
39-
bool attrRequiresRegionInitialization(mlir::Attribute attr) const;
40-
4139
mlir::LogicalResult matchAndRewriteRegionInitializedGlobal(
4240
cir::GlobalOp op, mlir::Attribute init,
4341
mlir::ConversionPatternRewriter &rewriter) const;

0 commit comments

Comments
 (0)