@@ -697,6 +697,16 @@ class CIRConstantOpLowering
697697 }
698698 return mlir::DenseElementsAttr::get (
699699 mlir::cast<mlir::ShapedType>(mlirType), mlirValues);
700+ } else if (auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(cirAttr)) {
701+ (void )zeroAttr;
702+ return rewriter.getZeroAttr (mlirType);
703+ } else if (auto complexAttr = mlir::dyn_cast<cir::ComplexAttr>(cirAttr)) {
704+ auto vecType = mlir::dyn_cast<mlir::VectorType>(mlirType);
705+ assert (vecType && " complex attribute lowered type should be a vector" );
706+ SmallVector<mlir::Attribute, 2 > elements{
707+ this ->lowerCirAttrToMlirAttr (complexAttr.getReal (), rewriter),
708+ this ->lowerCirAttrToMlirAttr (complexAttr.getImag (), rewriter)};
709+ return mlir::DenseElementsAttr::get (vecType, elements);
700710 } else if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(cirAttr)) {
701711 return rewriter.getIntegerAttr (mlirType, boolAttr.getValue ());
702712 } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(cirAttr)) {
@@ -1133,18 +1143,30 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
11331143 initialValue = init.value ();
11341144 else
11351145 llvm_unreachable (" GlobalOp lowering array with initial value fail" );
1136- } else if (auto constArr = mlir::dyn_cast<cir::ZeroAttr>(init.value ())) {
1146+ } else if (auto constComplex =
1147+ mlir::dyn_cast<cir::ComplexAttr>(init.value ())) {
1148+ if (auto lowered =
1149+ cir::direct::lowerConstComplexAttr (constComplex,
1150+ getTypeConverter ());
1151+ lowered.has_value ())
1152+ initialValue = lowered.value ();
1153+ else
1154+ llvm_unreachable (
1155+ " GlobalOp lowering complex with initial value failed" );
1156+ } else if (auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(init.value ())) {
1157+ (void )zeroAttr;
11371158 if (memrefType.getShape ().size ()) {
11381159 auto elementType = memrefType.getElementType ();
11391160 auto rtt =
11401161 mlir::RankedTensorType::get (memrefType.getShape (), elementType);
11411162 if (mlir::isa<mlir::IntegerType>(elementType))
11421163 initialValue = mlir::DenseIntElementsAttr::get (rtt, 0 );
11431164 else if (mlir::isa<mlir::FloatType>(elementType)) {
1144- auto floatZero = mlir::FloatAttr::get (elementType, 0.0 ).getValue ();
1165+ auto floatZero =
1166+ mlir::FloatAttr::get (elementType, 0.0 ).getValue ();
11451167 initialValue = mlir::DenseFPElementsAttr::get (rtt, floatZero);
11461168 } else
1147- llvm_unreachable ( " GlobalOp lowering unsuppored element type " );
1169+ initialValue = mlir::Attribute ( );
11481170 } else {
11491171 auto rtt = mlir::RankedTensorType::get ({}, convertedType);
11501172 if (mlir::isa<mlir::IntegerType>(convertedType))
@@ -1154,7 +1176,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
11541176 mlir::FloatAttr::get (convertedType, 0.0 ).getValue ();
11551177 initialValue = mlir::DenseFPElementsAttr::get (rtt, floatZero);
11561178 } else
1157- llvm_unreachable ( " GlobalOp lowering unsuppored type " );
1179+ initialValue = mlir::Attribute ( );
11581180 }
11591181 } else if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(init.value ())) {
11601182 auto rtt = mlir::RankedTensorType::get ({}, convertedType);
@@ -1207,6 +1229,67 @@ class CIRGetGlobalOpLowering
12071229 }
12081230};
12091231
1232+ class CIRComplexCreateOpLowering
1233+ : public mlir::OpConversionPattern<cir::ComplexCreateOp> {
1234+ public:
1235+ using OpConversionPattern<cir::ComplexCreateOp>::OpConversionPattern;
1236+
1237+ mlir::LogicalResult
1238+ matchAndRewrite (cir::ComplexCreateOp op, OpAdaptor adaptor,
1239+ mlir::ConversionPatternRewriter &rewriter) const override {
1240+ auto loc = op.getLoc ();
1241+ auto vecType =
1242+ mlir::cast<mlir::VectorType>(getTypeConverter ()->convertType (
1243+ op.getType ()));
1244+ auto zeroAttr = rewriter.getZeroAttr (vecType);
1245+ mlir::Value result =
1246+ rewriter.create <mlir::arith::ConstantOp>(loc, vecType, zeroAttr)
1247+ .getResult ();
1248+ SmallVector<int64_t , 1 > realIdx{0 };
1249+ SmallVector<int64_t , 1 > imagIdx{1 };
1250+ result = rewriter
1251+ .create <mlir::vector::InsertOp>(loc, adaptor.getReal (), result,
1252+ realIdx)
1253+ .getResult ();
1254+ result = rewriter
1255+ .create <mlir::vector::InsertOp>(loc, adaptor.getImag (), result,
1256+ imagIdx)
1257+ .getResult ();
1258+ rewriter.replaceOp (op, result);
1259+ return mlir::success ();
1260+ }
1261+ };
1262+
1263+ class CIRComplexRealOpLowering
1264+ : public mlir::OpConversionPattern<cir::ComplexRealOp> {
1265+ public:
1266+ using OpConversionPattern<cir::ComplexRealOp>::OpConversionPattern;
1267+
1268+ mlir::LogicalResult
1269+ matchAndRewrite (cir::ComplexRealOp op, OpAdaptor adaptor,
1270+ mlir::ConversionPatternRewriter &rewriter) const override {
1271+ SmallVector<int64_t , 1 > idx{0 };
1272+ rewriter.replaceOpWithNewOp <mlir::vector::ExtractOp>(
1273+ op, adaptor.getOperand (), idx);
1274+ return mlir::success ();
1275+ }
1276+ };
1277+
1278+ class CIRComplexImagOpLowering
1279+ : public mlir::OpConversionPattern<cir::ComplexImagOp> {
1280+ public:
1281+ using OpConversionPattern<cir::ComplexImagOp>::OpConversionPattern;
1282+
1283+ mlir::LogicalResult
1284+ matchAndRewrite (cir::ComplexImagOp op, OpAdaptor adaptor,
1285+ mlir::ConversionPatternRewriter &rewriter) const override {
1286+ SmallVector<int64_t , 1 > idx{1 };
1287+ rewriter.replaceOpWithNewOp <mlir::vector::ExtractOp>(
1288+ op, adaptor.getOperand (), idx);
1289+ return mlir::success ();
1290+ }
1291+ };
1292+
12101293class CIRVectorCreateLowering
12111294 : public mlir::OpConversionPattern<cir::VecCreateOp> {
12121295public:
@@ -1601,12 +1684,13 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
16011684 CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
16021685 CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
16031686 CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1604- CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1605- CIRGetElementOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1606- CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1607- CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1608- CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1609- CIRSinOpLowering, CIRTanOpLowering, CIRShiftOpLowering,
1687+ CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
1688+ CIRComplexRealOpLowering, CIRComplexImagOpLowering, CIRCastOpLowering,
1689+ CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1690+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1691+ CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1692+ CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1693+ CIRRoundOpLowering, CIRSinOpLowering, CIRTanOpLowering, CIRShiftOpLowering,
16101694 CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
16111695 CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
16121696 CIRIfOpLowering, CIRScopeOpLowering, CIRVectorCreateLowering,
@@ -1679,6 +1763,12 @@ static mlir::TypeConverter prepareTypeConverter() {
16791763 auto ty = converter.convertType (type.getElementType ());
16801764 return mlir::VectorType::get (type.getSize (), ty);
16811765 });
1766+ converter.addConversion ([&](cir::ComplexType type) -> mlir::Type {
1767+ auto elemTy = converter.convertType (type.getElementType ());
1768+ if (!elemTy)
1769+ return nullptr ;
1770+ return mlir::VectorType::get (2 , elemTy);
1771+ });
16821772 return converter;
16831773}
16841774
0 commit comments