@@ -188,15 +188,17 @@ class CIRAttrToValue {
188188
189189 mlir::Value visit (mlir::Attribute attr) {
190190 return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
191- .Case <cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr,
192- cir::ZeroAttr>([&](auto attrT) { return visitCirAttr (attrT); })
191+ .Case <cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
192+ cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
193+ [&](auto attrT) { return visitCirAttr (attrT); })
193194 .Default ([&](auto attrT) { return mlir::Value (); });
194195 }
195196
196197 mlir::Value visitCirAttr (cir::IntAttr intAttr);
197198 mlir::Value visitCirAttr (cir::FPAttr fltAttr);
198199 mlir::Value visitCirAttr (cir::ConstPtrAttr ptrAttr);
199200 mlir::Value visitCirAttr (cir::ConstArrayAttr attr);
201+ mlir::Value visitCirAttr (cir::ConstVectorAttr attr);
200202 mlir::Value visitCirAttr (cir::ZeroAttr attr);
201203
202204private:
@@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
275277 return result;
276278}
277279
280+ // / ConstVectorAttr visitor.
281+ mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstVectorAttr attr) {
282+ const mlir::Type llvmTy = converter->convertType (attr.getType ());
283+ const mlir::Location loc = parentOp->getLoc ();
284+
285+ SmallVector<mlir::Attribute> mlirValues;
286+ for (const mlir::Attribute elementAttr : attr.getElts ()) {
287+ mlir::Attribute mlirAttr;
288+ if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
289+ mlirAttr = rewriter.getIntegerAttr (
290+ converter->convertType (intAttr.getType ()), intAttr.getValue ());
291+ } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
292+ mlirAttr = rewriter.getFloatAttr (
293+ converter->convertType (floatAttr.getType ()), floatAttr.getValue ());
294+ } else {
295+ llvm_unreachable (
296+ " vector constant with an element that is neither an int nor a float" );
297+ }
298+ mlirValues.push_back (mlirAttr);
299+ }
300+
301+ return rewriter.create <mlir::LLVM::ConstantOp>(
302+ loc, llvmTy,
303+ mlir::DenseElementsAttr::get (mlir::cast<mlir::ShapedType>(llvmTy),
304+ mlirValues));
305+ }
306+
278307// / ZeroAttr visitor.
279308mlir::Value CIRAttrToValue::visitCirAttr (cir::ZeroAttr attr) {
280309 mlir::Location loc = parentOp->getLoc ();
@@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
888917 cir::GlobalOp op, mlir::Attribute init,
889918 mlir::ConversionPatternRewriter &rewriter) const {
890919 // TODO: Generalize this handling when more types are needed here.
891- assert ((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init)));
920+ assert ((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
921+ cir::ZeroAttr>(init)));
892922
893923 // TODO(cir): once LLVM's dialect has proper equivalent attributes this
894924 // should be updated. For now, we use a custom op to initialize globals
@@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
941971 op.emitError () << " unsupported initializer '" << init.value () << " '" ;
942972 return mlir::failure ();
943973 }
944- } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
945- init.value ())) {
974+ } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
975+ cir::ConstPtrAttr, cir::ZeroAttr>( init.value ())) {
946976 // TODO(cir): once LLVM's dialect has proper equivalent attributes this
947977 // should be updated. For now, we use a custom op to initialize globals
948978 // to the appropriate value.
0 commit comments