@@ -44,48 +44,52 @@ class CIRAttrToValue {
4444 const mlir::TypeConverter *converter)
4545 : parentOp(parentOp), rewriter(rewriter), converter(converter) {}
4646
47- mlir::Value lowerCirAttrAsValue (mlir::Attribute attr) { return visit (attr); }
48-
4947 mlir::Value visit (mlir::Attribute attr) {
5048 return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
5149 .Case <cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
5250 [&](auto attrT) { return visitCirAttr (attrT); })
5351 .Default ([&](auto attrT) { return mlir::Value (); });
5452 }
5553
56- mlir::Value visitCirAttr (cir::IntAttr intAttr) {
57- mlir::Location loc = parentOp->getLoc ();
58- return rewriter.create <mlir::LLVM::ConstantOp>(
59- loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
60- }
61-
62- mlir::Value visitCirAttr (cir::FPAttr fltAttr) {
63- mlir::Location loc = parentOp->getLoc ();
64- return rewriter.create <mlir::LLVM::ConstantOp>(
65- loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
66- }
67-
68- mlir::Value visitCirAttr (cir::ConstPtrAttr ptrAttr) {
69- mlir::Location loc = parentOp->getLoc ();
70- if (ptrAttr.isNullValue ()) {
71- return rewriter.create <mlir::LLVM::ZeroOp>(
72- loc, converter->convertType (ptrAttr.getType ()));
73- }
74- mlir::DataLayout layout (parentOp->getParentOfType <mlir::ModuleOp>());
75- mlir::Value ptrVal = rewriter.create <mlir::LLVM::ConstantOp>(
76- loc,
77- rewriter.getIntegerType (layout.getTypeSizeInBits (ptrAttr.getType ())),
78- ptrAttr.getValue ().getInt ());
79- return rewriter.create <mlir::LLVM::IntToPtrOp>(
80- loc, converter->convertType (ptrAttr.getType ()), ptrVal);
81- }
54+ mlir::Value visitCirAttr (cir::IntAttr intAttr);
55+ mlir::Value visitCirAttr (cir::FPAttr fltAttr);
56+ mlir::Value visitCirAttr (cir::ConstPtrAttr ptrAttr);
8257
8358private:
8459 mlir::Operation *parentOp;
8560 mlir::ConversionPatternRewriter &rewriter;
8661 const mlir::TypeConverter *converter;
8762};
8863
64+ // / IntAttr visitor.
65+ mlir::Value CIRAttrToValue::visitCirAttr (cir::IntAttr intAttr) {
66+ mlir::Location loc = parentOp->getLoc ();
67+ return rewriter.create <mlir::LLVM::ConstantOp>(
68+ loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
69+ }
70+
71+ // / ConstPtrAttr visitor.
72+ mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstPtrAttr ptrAttr) {
73+ mlir::Location loc = parentOp->getLoc ();
74+ if (ptrAttr.isNullValue ()) {
75+ return rewriter.create <mlir::LLVM::ZeroOp>(
76+ loc, converter->convertType (ptrAttr.getType ()));
77+ }
78+ mlir::DataLayout layout (parentOp->getParentOfType <mlir::ModuleOp>());
79+ mlir::Value ptrVal = rewriter.create <mlir::LLVM::ConstantOp>(
80+ loc, rewriter.getIntegerType (layout.getTypeSizeInBits (ptrAttr.getType ())),
81+ ptrAttr.getValue ().getInt ());
82+ return rewriter.create <mlir::LLVM::IntToPtrOp>(
83+ loc, converter->convertType (ptrAttr.getType ()), ptrVal);
84+ }
85+
86+ // / FPAttr visitor.
87+ mlir::Value CIRAttrToValue::visitCirAttr (cir::FPAttr fltAttr) {
88+ mlir::Location loc = parentOp->getLoc ();
89+ return rewriter.create <mlir::LLVM::ConstantOp>(
90+ loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
91+ }
92+
8993// This class handles rewriting initializer attributes for types that do not
9094// require region initialization.
9195class GlobalInitAttrRewriter {
@@ -94,8 +98,6 @@ class GlobalInitAttrRewriter {
9498 mlir::ConversionPatternRewriter &rewriter)
9599 : llvmType(type), rewriter(rewriter) {}
96100
97- mlir::Attribute rewriteInitAttr (mlir::Attribute attr) { return visit (attr); }
98-
99101 mlir::Attribute visit (mlir::Attribute attr) {
100102 return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
101103 .Case <cir::IntAttr, cir::FPAttr>(
@@ -137,12 +139,6 @@ struct ConvertCIRToLLVMPass
137139 StringRef getArgument () const override { return " cir-flat-to-llvm" ; }
138140};
139141
140- bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization (
141- mlir::Attribute attr) const {
142- // There will be more cases added later.
143- return isa<cir::ConstPtrAttr>(attr);
144- }
145-
146142// / Replace CIR global with a region initialized LLVM global and update
147143// / insertion point to the end of the initializer block.
148144void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp (
@@ -189,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
189185 // to the appropriate value.
190186 const mlir::Location loc = op.getLoc ();
191187 setupRegionInitializedLLVMGlobalOp (op, rewriter);
192- CIRAttrToValue attrVisitor (op, rewriter, typeConverter);
193- mlir::Value value = attrVisitor. lowerCirAttrAsValue (init);
188+ CIRAttrToValue valueConverter (op, rewriter, typeConverter);
189+ mlir::Value value = valueConverter. visit (init);
194190 rewriter.create <mlir::LLVM::ReturnOp>(loc, value);
195191 return mlir::success ();
196192}
@@ -201,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
201197
202198 std::optional<mlir::Attribute> init = op.getInitialValue ();
203199
204- // If we have an initializer and it requires region initialization, handle
205- // that separately
206- if (init.has_value () && attrRequiresRegionInitialization (init.value ())) {
207- return matchAndRewriteRegionInitializedGlobal (op, init.value (), rewriter);
208- }
209-
210200 // Fetch required values to create LLVM op.
211201 const mlir::Type cirSymType = op.getSymType ();
212202
@@ -231,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
231221 SmallVector<mlir::NamedAttribute> attributes;
232222
233223 if (init.has_value ()) {
234- GlobalInitAttrRewriter initRewriter (llvmType, rewriter);
235- init = initRewriter.rewriteInitAttr (init.value ());
236- // If initRewriter returned a null attribute, init will have a value but
237- // the value will be null. If that happens, initRewriter didn't handle the
238- // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
239- if (!init.value ()) {
224+ if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value ())) {
225+ // If a directly equivalent attribute is available, use it.
226+ init =
227+ llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value ())
228+ .Case <cir::FPAttr>([&](cir::FPAttr attr) {
229+ return rewriter.getFloatAttr (llvmType, attr.getValue ());
230+ })
231+ .Case <cir::IntAttr>([&](cir::IntAttr attr) {
232+ return rewriter.getIntegerAttr (llvmType, attr.getValue ());
233+ })
234+ .Default ([&](mlir::Attribute attr) { return mlir::Attribute (); });
235+ // If initRewriter returned a null attribute, init will have a value but
236+ // the value will be null.
237+ if (!init.value ()) {
238+ op.emitError () << " unsupported initializer '" << init.value () << " '" ;
239+ return mlir::failure ();
240+ }
241+ } else if (mlir::isa<cir::ConstPtrAttr>(init.value ())) {
242+ // TODO(cir): once LLVM's dialect has proper equivalent attributes this
243+ // should be updated. For now, we use a custom op to initialize globals
244+ // to the appropriate value.
245+ return matchAndRewriteRegionInitializedGlobal (op, init.value (), rewriter);
246+ } else {
247+ // We will only get here if new initializer types are added and this
248+ // code is not updated to handle them.
240249 op.emitError () << " unsupported initializer '" << init.value () << " '" ;
241250 return mlir::failure ();
242251 }
0 commit comments