Skip to content

[CIR] Backport support for global ComplexType init #1665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/LoweringHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
std::optional<mlir::Attribute>
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
const mlir::TypeConverter *converter);

std::optional<mlir::Attribute>
lowerConstComplexAttr(cir::ComplexAttr constArr,
const mlir::TypeConverter *converter);

#endif
25 changes: 23 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2007,9 +2007,30 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
case APValue::Struct:
case APValue::Union:
return ConstRecordBuilder::BuildRecord(*this, Value, DestType);
case APValue::FixedPoint:
case APValue::ComplexInt:
case APValue::ComplexFloat:
case APValue::ComplexInt: {
mlir::Type desiredType = CGM.convertType(DestType);
cir::ComplexType complexType =
mlir::dyn_cast<cir::ComplexType>(desiredType);

mlir::Type complexElemTy = complexType.getElementType();
if (isa<cir::IntType>(complexElemTy)) {
llvm::APSInt real = Value.getComplexIntReal();
llvm::APSInt imag = Value.getComplexIntImag();
return builder.getAttr<cir::ComplexAttr>(
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
}

assert(isa<cir::CIRFPTypeInterface>(complexElemTy) &&
"expected floating-point type");
llvm::APFloat real = Value.getComplexFloatReal();
llvm::APFloat imag = Value.getComplexFloatImag();
return builder.getAttr<cir::ComplexAttr>(
complexType, builder.getAttr<cir::FPAttr>(complexElemTy, real),
builder.getAttr<cir::FPAttr>(complexElemTy, imag));
}
case APValue::FixedPoint:
case APValue::AddrLabelDiff:
assert(0 && "not implemented");
}
Expand Down
48 changes: 46 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,8 @@ class CirAttrToValue {
.Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr,
cir::ConstRecordAttr, cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::BoolAttr, cir::ZeroAttr, cir::UndefAttr, cir::PoisonAttr,
cir::GlobalViewAttr, cir::VTableAttr, cir::TypeInfoAttr>(
[&](auto attrT) { return visitCirAttr(attrT); })
cir::GlobalViewAttr, cir::VTableAttr, cir::TypeInfoAttr,
cir::ComplexAttr>([&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}

Expand All @@ -463,6 +463,7 @@ class CirAttrToValue {
mlir::Value visitCirAttr(cir::ConstRecordAttr attr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirAttr(cir::ComplexAttr attr);
mlir::Value visitCirAttr(cir::BoolAttr attr);
mlir::Value visitCirAttr(cir::ZeroAttr attr);
mlir::Value visitCirAttr(cir::UndefAttr attr);
Expand Down Expand Up @@ -647,6 +648,33 @@ mlir::Value CirAttrToValue::visitCirAttr(cir::ConstVectorAttr constVec) {
mlirValues));
}

mlir::Value CirAttrToValue::visitCirAttr(cir::ComplexAttr complexAttr) {
auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
mlir::Type complexElemTy = complexType.getElementType();
mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy);

mlir::Attribute components[2];
if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
components[0] = rewriter.getIntegerAttr(
complexElemLLVMTy,
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
components[1] = rewriter.getIntegerAttr(
complexElemLLVMTy,
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
} else {
components[0] = rewriter.getFloatAttr(
complexElemLLVMTy,
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
components[1] = rewriter.getFloatAttr(
complexElemLLVMTy,
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
}

mlir::Location loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(complexAttr.getType()),
rewriter.getArrayAttr(components));
}
// GlobalViewAttr visitor.
mlir::Value CirAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) {
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
Expand Down Expand Up @@ -2428,6 +2456,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
} else if (mlir::isa<cir::ConstVectorAttr>(init)) {
return lowerInitializerForConstVector(rewriter, op, init,
useInitializerRegion);
} else if (mlir::isa<cir::ComplexAttr>(init)) {
return lowerInitializerForConstComplex(rewriter, op, init,
useInitializerRegion);
} else if (auto dataMemberAttr = mlir::dyn_cast<cir::DataMemberAttr>(init)) {
assert(lowerMod && "lower module is not available");
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
Expand Down Expand Up @@ -2464,6 +2495,19 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstVector(
return mlir::failure();
}

mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstComplex(
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
mlir::Attribute &init, bool &useInitializerRegion) const {
auto constVec = mlir::cast<cir::ComplexAttr>(init);
if (auto val = lowerConstComplexAttr(constVec, getTypeConverter());
val.has_value()) {
init = val.value();
useInitializerRegion = false;
} else
useInitializerRegion = true;
return mlir::success();
}

mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray(
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
mlir::Attribute &init, bool &useInitializerRegion) const {
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,11 @@ class CIRToLLVMGlobalOpLowering
cir::GlobalOp op, mlir::Attribute &init,
bool &useInitializerRegion) const;

mlir::LogicalResult
lowerInitializerForConstComplex(mlir::ConversionPatternRewriter &rewriter,
cir::GlobalOp op, mlir::Attribute &init,
bool &useInitializerRegion) const;

mlir::LogicalResult
lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter,
cir::GlobalOp op, mlir::Type llvmType,
Expand Down
62 changes: 62 additions & 0 deletions clang/lib/CIR/Lowering/LoweringHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,33 @@ void convertToDenseElementsAttrImpl(
}
}

template <typename AttrTy, typename StorageTy>
void convertToDenseElementsAttrImpl(
cir::ComplexAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
int64_t currentIndex) {
dimIndex++;
std::size_t elementsSizeInCurrentDim = 1;
for (std::size_t i = dimIndex; i < currentDims.size(); i++)
elementsSizeInCurrentDim *= currentDims[i];

auto attrArray =
mlir::ArrayAttr::get(attr.getContext(), {attr.getImag(), attr.getReal()});
for (auto eltAttr : attrArray) {
if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
values[currentIndex++] = valueAttr.getValue();
continue;
}

if (mlir::isa<cir::ZeroAttr, cir::UndefAttr>(eltAttr)) {
currentIndex += elementsSizeInCurrentDim;
continue;
}

llvm_unreachable("unknown element in ComplexAttr");
}
}

template <typename AttrTy, typename StorageTy>
mlir::DenseElementsAttr convertToDenseElementsAttr(
cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
Expand Down Expand Up @@ -158,6 +185,20 @@ mlir::DenseElementsAttr convertToDenseElementsAttr(
llvm::ArrayRef(values));
}

template <typename AttrTy, typename StorageTy>
mlir::DenseElementsAttr convertToDenseElementsAttr(
cir::ComplexAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
mlir::Type elementType, mlir::Type convertedElementType) {
unsigned array_size = 2;
auto values = llvm::SmallVector<StorageTy, 8>(
array_size, getZeroInitFromType<StorageTy>(elementType));
convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0,
/*initialIndex=*/0);
return mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get(dims, convertedElementType),
llvm::ArrayRef(values));
}

std::optional<mlir::Attribute>
lowerConstArrayAttr(cir::ConstArrayAttr constArr,
const mlir::TypeConverter *converter) {
Expand Down Expand Up @@ -191,6 +232,27 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
return std::nullopt;
}

std::optional<mlir::Attribute>
lowerConstComplexAttr(cir::ComplexAttr constComplex,
const mlir::TypeConverter *converter) {

// Ensure ComplexAttr has a type.
auto typedConstArr = mlir::dyn_cast<mlir::TypedAttr>(constComplex);
assert(typedConstArr && "cir::ComplexAttr is not a mlir::TypedAttr");

mlir::Type type = constComplex.getType();
auto dims = llvm::SmallVector<int64_t, 2>{2};

if (mlir::isa<cir::IntType>(type))
return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
constComplex, dims, type, converter->convertType(type));
if (mlir::isa<cir::CIRFPTypeInterface>(type))
return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
constComplex, dims, type, converter->convertType(type));

return std::nullopt;
}

std::optional<mlir::Attribute>
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
const mlir::TypeConverter *converter) {
Expand Down
22 changes: 22 additions & 0 deletions clang/test/CIR/CodeGen/const-complex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CHECK
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM

int _Complex gci;

float _Complex gcf;

int _Complex gci2 = { 1, 2 };

float _Complex gcf2 = { 1.0f, 2.0f };

// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!s32i>
// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!cir.float>
// CHECK: cir.global external {{.*}} = #cir.complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
// CHECK: cir.global external {{.*}} = #cir.complex<#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00> : !cir.float> : !cir.complex<!cir.float>

// LLVM: {{.*}} = global { i32, i32 } zeroinitializer, align 4
// LLVM: {{.*}} = global { float, float } zeroinitializer, align 4
// LLVM: {{.*}} = global { i32, i32 } { i32 1, i32 2 }, align 4
// LLVM: {{.*}} = global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4
Loading