2424#include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2525#include " mlir/Target/LLVMIR/Export.h"
2626#include " mlir/Transforms/DialectConversion.h"
27- #include " clang/CIR/Dialect/IR/CIRAttrVisitor.h"
2827#include " clang/CIR/Dialect/IR/CIRDialect.h"
2928#include " clang/CIR/MissingFeatures.h"
3029#include " clang/CIR/Passes.h"
30+ #include " llvm/ADT/TypeSwitch.h"
3131#include " llvm/IR/Module.h"
3232#include " llvm/Support/TimeProfiler.h"
3333
@@ -37,63 +37,78 @@ using namespace llvm;
3737namespace cir {
3838namespace direct {
3939
40- class CIRAttrToValue : public CirAttrVisitor <CIRAttrToValue, mlir::Value> {
40+ class CIRAttrToValue {
4141public:
4242 CIRAttrToValue (mlir::Operation *parentOp,
4343 mlir::ConversionPatternRewriter &rewriter,
4444 const mlir::TypeConverter *converter)
4545 : parentOp(parentOp), rewriter(rewriter), converter(converter) {}
4646
47- mlir::Value lowerCirAttrAsValue (mlir::Attribute attr) { return visit (attr); }
48-
49- mlir::Value visitCirIntAttr (cir::IntAttr intAttr) {
50- mlir::Location loc = parentOp->getLoc ();
51- return rewriter.create <mlir::LLVM::ConstantOp>(
52- loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
53- }
54-
55- mlir::Value visitCirFPAttr (cir::FPAttr fltAttr) {
56- mlir::Location loc = parentOp->getLoc ();
57- return rewriter.create <mlir::LLVM::ConstantOp>(
58- loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
47+ mlir::Value visit (mlir::Attribute attr) {
48+ return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
49+ .Case <cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
50+ [&](auto attrT) { return visitCirAttr (attrT); })
51+ .Default ([&](auto attrT) { return mlir::Value (); });
5952 }
6053
61- mlir::Value visitCirConstPtrAttr (cir::ConstPtrAttr ptrAttr) {
62- mlir::Location loc = parentOp->getLoc ();
63- if (ptrAttr.isNullValue ()) {
64- return rewriter.create <mlir::LLVM::ZeroOp>(
65- loc, converter->convertType (ptrAttr.getType ()));
66- }
67- mlir::DataLayout layout (parentOp->getParentOfType <mlir::ModuleOp>());
68- mlir::Value ptrVal = rewriter.create <mlir::LLVM::ConstantOp>(
69- loc,
70- rewriter.getIntegerType (layout.getTypeSizeInBits (ptrAttr.getType ())),
71- ptrAttr.getValue ().getInt ());
72- return rewriter.create <mlir::LLVM::IntToPtrOp>(
73- loc, converter->convertType (ptrAttr.getType ()), ptrVal);
74- }
54+ mlir::Value visitCirAttr (cir::IntAttr intAttr);
55+ mlir::Value visitCirAttr (cir::FPAttr fltAttr);
56+ mlir::Value visitCirAttr (cir::ConstPtrAttr ptrAttr);
7557
7658private:
7759 mlir::Operation *parentOp;
7860 mlir::ConversionPatternRewriter &rewriter;
7961 const mlir::TypeConverter *converter;
8062};
8163
64+ // / IntAttr visitor.
65+ mlir::Value CIRAttrToValue::visitCirAttr (cir::IntAttr intAttr) {
66+ mlir::Location loc = parentOp->getLoc ();
67+ return rewriter.create <mlir::LLVM::ConstantOp>(
68+ loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
69+ }
70+
71+ // / ConstPtrAttr visitor.
72+ mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstPtrAttr ptrAttr) {
73+ mlir::Location loc = parentOp->getLoc ();
74+ if (ptrAttr.isNullValue ()) {
75+ return rewriter.create <mlir::LLVM::ZeroOp>(
76+ loc, converter->convertType (ptrAttr.getType ()));
77+ }
78+ mlir::DataLayout layout (parentOp->getParentOfType <mlir::ModuleOp>());
79+ mlir::Value ptrVal = rewriter.create <mlir::LLVM::ConstantOp>(
80+ loc, rewriter.getIntegerType (layout.getTypeSizeInBits (ptrAttr.getType ())),
81+ ptrAttr.getValue ().getInt ());
82+ return rewriter.create <mlir::LLVM::IntToPtrOp>(
83+ loc, converter->convertType (ptrAttr.getType ()), ptrVal);
84+ }
85+
86+ // / FPAttr visitor.
87+ mlir::Value CIRAttrToValue::visitCirAttr (cir::FPAttr fltAttr) {
88+ mlir::Location loc = parentOp->getLoc ();
89+ return rewriter.create <mlir::LLVM::ConstantOp>(
90+ loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
91+ }
92+
8293// This class handles rewriting initializer attributes for types that do not
8394// require region initialization.
84- class GlobalInitAttrRewriter
85- : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
95+ class GlobalInitAttrRewriter {
8696public:
8797 GlobalInitAttrRewriter (mlir::Type type,
8898 mlir::ConversionPatternRewriter &rewriter)
8999 : llvmType(type), rewriter(rewriter) {}
90100
91- mlir::Attribute rewriteInitAttr (mlir::Attribute attr) { return visit (attr); }
101+ mlir::Attribute visit (mlir::Attribute attr) {
102+ return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
103+ .Case <cir::IntAttr, cir::FPAttr>(
104+ [&](auto attrT) { return visitCirAttr (attrT); })
105+ .Default ([&](auto attrT) { return mlir::Attribute (); });
106+ }
92107
93- mlir::Attribute visitCirIntAttr (cir::IntAttr attr) {
108+ mlir::Attribute visitCirAttr (cir::IntAttr attr) {
94109 return rewriter.getIntegerAttr (llvmType, attr.getValue ());
95110 }
96- mlir::Attribute visitCirFPAttr (cir::FPAttr attr) {
111+ mlir::Attribute visitCirAttr (cir::FPAttr attr) {
97112 return rewriter.getFloatAttr (llvmType, attr.getValue ());
98113 }
99114
@@ -124,12 +139,6 @@ struct ConvertCIRToLLVMPass
124139 StringRef getArgument () const override { return " cir-flat-to-llvm" ; }
125140};
126141
127- bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization (
128- mlir::Attribute attr) const {
129- // There will be more cases added later.
130- return isa<cir::ConstPtrAttr>(attr);
131- }
132-
133142// / Replace CIR global with a region initialized LLVM global and update
134143// / insertion point to the end of the initializer block.
135144void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp (
@@ -176,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
176185 // to the appropriate value.
177186 const mlir::Location loc = op.getLoc ();
178187 setupRegionInitializedLLVMGlobalOp (op, rewriter);
179- CIRAttrToValue attrVisitor (op, rewriter, typeConverter);
180- mlir::Value value = attrVisitor. lowerCirAttrAsValue (init);
188+ CIRAttrToValue valueConverter (op, rewriter, typeConverter);
189+ mlir::Value value = valueConverter. visit (init);
181190 rewriter.create <mlir::LLVM::ReturnOp>(loc, value);
182191 return mlir::success ();
183192}
@@ -188,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
188197
189198 std::optional<mlir::Attribute> init = op.getInitialValue ();
190199
191- // If we have an initializer and it requires region initialization, handle
192- // that separately
193- if (init.has_value () && attrRequiresRegionInitialization (init.value ())) {
194- return matchAndRewriteRegionInitializedGlobal (op, init.value (), rewriter);
195- }
196-
197200 // Fetch required values to create LLVM op.
198201 const mlir::Type cirSymType = op.getSymType ();
199202
@@ -218,12 +221,25 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
218221 SmallVector<mlir::NamedAttribute> attributes;
219222
220223 if (init.has_value ()) {
221- GlobalInitAttrRewriter initRewriter (llvmType, rewriter);
222- init = initRewriter.rewriteInitAttr (init.value ());
223- // If initRewriter returned a null attribute, init will have a value but
224- // the value will be null. If that happens, initRewriter didn't handle the
225- // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
226- if (!init.value ()) {
224+ if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value ())) {
225+ GlobalInitAttrRewriter initRewriter (llvmType, rewriter);
226+ init = initRewriter.visit (init.value ());
227+ // If initRewriter returned a null attribute, init will have a value but
228+ // the value will be null. If that happens, initRewriter didn't handle the
229+ // attribute type. It probably needs to be added to
230+ // GlobalInitAttrRewriter.
231+ if (!init.value ()) {
232+ op.emitError () << " unsupported initializer '" << init.value () << " '" ;
233+ return mlir::failure ();
234+ }
235+ } else if (mlir::isa<cir::ConstPtrAttr>(init.value ())) {
236+ // TODO(cir): once LLVM's dialect has proper equivalent attributes this
237+ // should be updated. For now, we use a custom op to initialize globals
238+ // to the appropriate value.
239+ return matchAndRewriteRegionInitializedGlobal (op, init.value (), rewriter);
240+ } else {
241+ // We will only get here if new initializer types are added and this
242+ // code is not updated to handle them.
227243 op.emitError () << " unsupported initializer '" << init.value () << " '" ;
228244 return mlir::failure ();
229245 }
0 commit comments