diff --git a/clang/include/clang/CIR/LoweringHelpers.h b/clang/include/clang/CIR/LoweringHelpers.h index 32f9f3b3a98b..483fde742934 100644 --- a/clang/include/clang/CIR/LoweringHelpers.h +++ b/clang/include/clang/CIR/LoweringHelpers.h @@ -50,4 +50,9 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr, std::optional lowerConstVectorAttr(cir::ConstVectorAttr constArr, const mlir::TypeConverter *converter); + +std::optional +lowerConstComplexAttr(cir::ComplexAttr constArr, + const mlir::TypeConverter *converter); + #endif diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp index 1c34651318d3..426f6eb0411d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp @@ -2007,9 +2007,30 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value, case APValue::Struct: case APValue::Union: return ConstRecordBuilder::BuildRecord(*this, Value, DestType); - case APValue::FixedPoint: - case APValue::ComplexInt: case APValue::ComplexFloat: + case APValue::ComplexInt: { + mlir::Type desiredType = CGM.convertType(DestType); + cir::ComplexType complexType = + mlir::dyn_cast(desiredType); + + mlir::Type complexElemTy = complexType.getElementType(); + if (isa(complexElemTy)) { + llvm::APSInt real = Value.getComplexIntReal(); + llvm::APSInt imag = Value.getComplexIntImag(); + return builder.getAttr( + complexType, builder.getAttr(complexElemTy, real), + builder.getAttr(complexElemTy, imag)); + } + + assert(isa(complexElemTy) && + "expected floating-point type"); + llvm::APFloat real = Value.getComplexFloatReal(); + llvm::APFloat imag = Value.getComplexFloatImag(); + return builder.getAttr( + complexType, builder.getAttr(complexElemTy, real), + builder.getAttr(complexElemTy, imag)); + } + case APValue::FixedPoint: case APValue::AddrLabelDiff: assert(0 && "not implemented"); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 66dce5d80d08..604f5aba80ca 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -452,8 +452,8 @@ class CirAttrToValue { .Case( - [&](auto attrT) { return visitCirAttr(attrT); }) + cir::GlobalViewAttr, cir::VTableAttr, cir::TypeInfoAttr, + cir::ComplexAttr>([&](auto attrT) { return visitCirAttr(attrT); }) .Default([&](auto attrT) { return mlir::Value(); }); } @@ -463,6 +463,7 @@ class CirAttrToValue { mlir::Value visitCirAttr(cir::ConstRecordAttr attr); mlir::Value visitCirAttr(cir::ConstArrayAttr attr); mlir::Value visitCirAttr(cir::ConstVectorAttr attr); + mlir::Value visitCirAttr(cir::ComplexAttr attr); mlir::Value visitCirAttr(cir::BoolAttr attr); mlir::Value visitCirAttr(cir::ZeroAttr attr); mlir::Value visitCirAttr(cir::UndefAttr attr); @@ -647,6 +648,33 @@ mlir::Value CirAttrToValue::visitCirAttr(cir::ConstVectorAttr constVec) { mlirValues)); } +mlir::Value CirAttrToValue::visitCirAttr(cir::ComplexAttr complexAttr) { + auto complexType = mlir::cast(complexAttr.getType()); + mlir::Type complexElemTy = complexType.getElementType(); + mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy); + + mlir::Attribute components[2]; + if (const auto intType = mlir::dyn_cast(complexElemTy)) { + components[0] = rewriter.getIntegerAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getReal()).getValue()); + components[1] = rewriter.getIntegerAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getImag()).getValue()); + } else { + components[0] = rewriter.getFloatAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getReal()).getValue()); + components[1] = rewriter.getFloatAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getImag()).getValue()); + } + + mlir::Location loc = parentOp->getLoc(); + return rewriter.create( + loc, converter->convertType(complexAttr.getType()), + rewriter.getArrayAttr(components)); +} // GlobalViewAttr visitor. mlir::Value CirAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) { auto module = parentOp->getParentOfType(); @@ -2428,6 +2456,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer( } else if (mlir::isa(init)) { return lowerInitializerForConstVector(rewriter, op, init, useInitializerRegion); + } else if (mlir::isa(init)) { + return lowerInitializerForConstComplex(rewriter, op, init, + useInitializerRegion); } else if (auto dataMemberAttr = mlir::dyn_cast(init)) { assert(lowerMod && "lower module is not available"); mlir::DataLayout layout(op->getParentOfType()); @@ -2464,6 +2495,19 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstVector( return mlir::failure(); } +mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstComplex( + mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op, + mlir::Attribute &init, bool &useInitializerRegion) const { + auto constVec = mlir::cast(init); + if (auto val = lowerConstComplexAttr(constVec, getTypeConverter()); + val.has_value()) { + init = val.value(); + useInitializerRegion = false; + } else + useInitializerRegion = true; + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray( mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op, mlir::Attribute &init, bool &useInitializerRegion) const { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 74edf58b347a..ad3fa1983c38 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -628,6 +628,11 @@ class CIRToLLVMGlobalOpLowering cir::GlobalOp op, mlir::Attribute &init, bool &useInitializerRegion) const; + mlir::LogicalResult + lowerInitializerForConstComplex(mlir::ConversionPatternRewriter &rewriter, + cir::GlobalOp op, mlir::Attribute &init, + bool &useInitializerRegion) const; + mlir::LogicalResult lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op, mlir::Type llvmType, diff --git a/clang/lib/CIR/Lowering/LoweringHelpers.cpp b/clang/lib/CIR/Lowering/LoweringHelpers.cpp index 1cfcc966e8a8..0dd209fdb6b9 100644 --- a/clang/lib/CIR/Lowering/LoweringHelpers.cpp +++ b/clang/lib/CIR/Lowering/LoweringHelpers.cpp @@ -126,6 +126,33 @@ void convertToDenseElementsAttrImpl( } } +template +void convertToDenseElementsAttrImpl( + cir::ComplexAttr attr, llvm::SmallVectorImpl &values, + const llvm::SmallVectorImpl ¤tDims, int64_t dimIndex, + int64_t currentIndex) { + dimIndex++; + std::size_t elementsSizeInCurrentDim = 1; + for (std::size_t i = dimIndex; i < currentDims.size(); i++) + elementsSizeInCurrentDim *= currentDims[i]; + + auto attrArray = + mlir::ArrayAttr::get(attr.getContext(), {attr.getImag(), attr.getReal()}); + for (auto eltAttr : attrArray) { + if (auto valueAttr = mlir::dyn_cast(eltAttr)) { + values[currentIndex++] = valueAttr.getValue(); + continue; + } + + if (mlir::isa(eltAttr)) { + currentIndex += elementsSizeInCurrentDim; + continue; + } + + llvm_unreachable("unknown element in ComplexAttr"); + } +} + template mlir::DenseElementsAttr convertToDenseElementsAttr( cir::ConstArrayAttr attr, const llvm::SmallVectorImpl &dims, @@ -158,6 +185,20 @@ mlir::DenseElementsAttr convertToDenseElementsAttr( llvm::ArrayRef(values)); } +template +mlir::DenseElementsAttr convertToDenseElementsAttr( + cir::ComplexAttr attr, const llvm::SmallVectorImpl &dims, + mlir::Type elementType, mlir::Type convertedElementType) { + unsigned array_size = 2; + auto values = llvm::SmallVector( + array_size, getZeroInitFromType(elementType)); + convertToDenseElementsAttrImpl(attr, values, dims, /*currentDim=*/0, + /*initialIndex=*/0); + return mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get(dims, convertedElementType), + llvm::ArrayRef(values)); +} + std::optional lowerConstArrayAttr(cir::ConstArrayAttr constArr, const mlir::TypeConverter *converter) { @@ -191,6 +232,27 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr, return std::nullopt; } +std::optional +lowerConstComplexAttr(cir::ComplexAttr constComplex, + const mlir::TypeConverter *converter) { + + // Ensure ComplexAttr has a type. + auto typedConstArr = mlir::dyn_cast(constComplex); + assert(typedConstArr && "cir::ComplexAttr is not a mlir::TypedAttr"); + + mlir::Type type = constComplex.getType(); + auto dims = llvm::SmallVector{2}; + + if (mlir::isa(type)) + return convertToDenseElementsAttr( + constComplex, dims, type, converter->convertType(type)); + if (mlir::isa(type)) + return convertToDenseElementsAttr( + constComplex, dims, type, converter->convertType(type)); + + return std::nullopt; +} + std::optional lowerConstVectorAttr(cir::ConstVectorAttr constArr, const mlir::TypeConverter *converter) { diff --git a/clang/test/CIR/CodeGen/const-complex.cpp b/clang/test/CIR/CodeGen/const-complex.cpp new file mode 100644 index 000000000000..76b4f8a0284d --- /dev/null +++ b/clang/test/CIR/CodeGen/const-complex.cpp @@ -0,0 +1,22 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CHECK +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM + +int _Complex gci; + +float _Complex gcf; + +int _Complex gci2 = { 1, 2 }; + +float _Complex gcf2 = { 1.0f, 2.0f }; + +// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex +// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex +// CHECK: cir.global external {{.*}} = #cir.complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex +// CHECK: cir.global external {{.*}} = #cir.complex<#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00> : !cir.float> : !cir.complex + +// LLVM: {{.*}} = global { i32, i32 } zeroinitializer, align 4 +// LLVM: {{.*}} = global { float, float } zeroinitializer, align 4 +// LLVM: {{.*}} = global { i32, i32 } { i32 1, i32 2 }, align 4 +// LLVM: {{.*}} = global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4 \ No newline at end of file