@@ -188,14 +188,15 @@ class CIRAttrToValue {
188188
189189 mlir::Value visit (mlir::Attribute attr) {
190190 return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
191- .Case <cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr ,
192- cir::ConstVectorAttr , cir::ConstPtrAttr , cir::ZeroAttr>(
193- [&](auto attrT) { return visitCirAttr (attrT); })
191+ .Case <cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr ,
192+ cir::ConstArrayAttr , cir::ConstVectorAttr , cir::ConstPtrAttr,
193+ cir::ZeroAttr>( [&](auto attrT) { return visitCirAttr (attrT); })
194194 .Default ([&](auto attrT) { return mlir::Value (); });
195195 }
196196
197197 mlir::Value visitCirAttr (cir::IntAttr intAttr);
198198 mlir::Value visitCirAttr (cir::FPAttr fltAttr);
199+ mlir::Value visitCirAttr (cir::ConstComplexAttr complexAttr);
199200 mlir::Value visitCirAttr (cir::ConstPtrAttr ptrAttr);
200201 mlir::Value visitCirAttr (cir::ConstArrayAttr attr);
201202 mlir::Value visitCirAttr (cir::ConstVectorAttr attr);
@@ -226,6 +227,42 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
226227 loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
227228}
228229
230+ // / FPAttr visitor.
231+ mlir::Value CIRAttrToValue::visitCirAttr (cir::FPAttr fltAttr) {
232+ mlir::Location loc = parentOp->getLoc ();
233+ return rewriter.create <mlir::LLVM::ConstantOp>(
234+ loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
235+ }
236+
237+ // / ConstComplexAttr visitor.
238+ mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstComplexAttr complexAttr) {
239+ auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType ());
240+ auto complexElemTy = complexType.getElementType ();
241+ auto complexElemLLVMTy = converter->convertType (complexElemTy);
242+
243+ mlir::Attribute components[2 ];
244+ if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
245+ components[0 ] = rewriter.getIntegerAttr (
246+ complexElemLLVMTy,
247+ mlir::cast<cir::IntAttr>(complexAttr.getReal ()).getValue ());
248+ components[1 ] = rewriter.getIntegerAttr (
249+ complexElemLLVMTy,
250+ mlir::cast<cir::IntAttr>(complexAttr.getImag ()).getValue ());
251+ } else {
252+ components[0 ] = rewriter.getFloatAttr (
253+ complexElemLLVMTy,
254+ mlir::cast<cir::FPAttr>(complexAttr.getReal ()).getValue ());
255+ components[1 ] = rewriter.getFloatAttr (
256+ complexElemLLVMTy,
257+ mlir::cast<cir::FPAttr>(complexAttr.getImag ()).getValue ());
258+ }
259+
260+ mlir::Location loc = parentOp->getLoc ();
261+ return rewriter.create <mlir::LLVM::ConstantOp>(
262+ loc, converter->convertType (complexAttr.getType ()),
263+ rewriter.getArrayAttr (components));
264+ }
265+
229266// / ConstPtrAttr visitor.
230267mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstPtrAttr ptrAttr) {
231268 mlir::Location loc = parentOp->getLoc ();
@@ -241,13 +278,6 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
241278 loc, converter->convertType (ptrAttr.getType ()), ptrVal);
242279}
243280
244- // / FPAttr visitor.
245- mlir::Value CIRAttrToValue::visitCirAttr (cir::FPAttr fltAttr) {
246- mlir::Location loc = parentOp->getLoc ();
247- return rewriter.create <mlir::LLVM::ConstantOp>(
248- loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
249- }
250-
251281// ConstArrayAttr visitor
252282mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstArrayAttr attr) {
253283 mlir::Type llvmTy = converter->convertType (attr.getType ());
@@ -341,9 +371,11 @@ class GlobalInitAttrRewriter {
341371 mlir::Attribute visitCirAttr (cir::IntAttr attr) {
342372 return rewriter.getIntegerAttr (llvmType, attr.getValue ());
343373 }
374+
344375 mlir::Attribute visitCirAttr (cir::FPAttr attr) {
345376 return rewriter.getFloatAttr (llvmType, attr.getValue ());
346377 }
378+
347379 mlir::Attribute visitCirAttr (cir::BoolAttr attr) {
348380 return rewriter.getBoolAttr (attr.getValue ());
349381 }
@@ -986,7 +1018,7 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
9861018 mlir::ConversionPatternRewriter &rewriter) const {
9871019 // TODO: Generalize this handling when more types are needed here.
9881020 assert ((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
989- cir::ZeroAttr>(init)));
1021+ cir::ConstComplexAttr, cir:: ZeroAttr>(init)));
9901022
9911023 // TODO(cir): once LLVM's dialect has proper equivalent attributes this
9921024 // should be updated. For now, we use a custom op to initialize globals
@@ -1039,7 +1071,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
10391071 return mlir::failure ();
10401072 }
10411073 } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
1042- cir::ConstPtrAttr, cir::ZeroAttr>(init.value ())) {
1074+ cir::ConstPtrAttr, cir::ConstComplexAttr,
1075+ cir::ZeroAttr>(init.value ())) {
10431076 // TODO(cir): once LLVM's dialect has proper equivalent attributes this
10441077 // should be updated. For now, we use a custom op to initialize globals
10451078 // to the appropriate value.
@@ -1571,6 +1604,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
15711604 converter.addConversion ([&](cir::BF16Type type) -> mlir::Type {
15721605 return mlir::BFloat16Type::get (type.getContext ());
15731606 });
1607+ converter.addConversion ([&](cir::ComplexType type) -> mlir::Type {
1608+ // A complex type is lowered to an LLVM struct that contains the real and
1609+ // imaginary part as data fields.
1610+ mlir::Type elementTy = converter.convertType (type.getElementType ());
1611+ mlir::Type structFields[2 ] = {elementTy, elementTy};
1612+ return mlir::LLVM::LLVMStructType::getLiteral (type.getContext (),
1613+ structFields);
1614+ });
15741615 converter.addConversion ([&](cir::FuncType type) -> std::optional<mlir::Type> {
15751616 auto result = converter.convertType (type.getReturnType ());
15761617 llvm::SmallVector<mlir::Type> arguments;
0 commit comments