Skip to content

Commit cf6923e

Browse files
AmrDeveloperlanza
authored andcommitted
[CIR] Backport global initializer for ComplexType (#1956)
Backport global initializer for ComplexType from the upstream
1 parent 5b631e9 commit cf6923e

File tree

7 files changed

+226
-45
lines changed

7 files changed

+226
-45
lines changed

clang/include/clang/CIR/LoweringHelpers.h

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#ifndef LLVM_CLANG_CIR_LOWERINGHELPERS_H
1313
#define LLVM_CLANG_CIR_LOWERINGHELPERS_H
1414

15+
#include <cstdint>
1516
#include <optional>
1617

1718
#include "llvm/ADT/SmallVector.h"
@@ -34,20 +35,37 @@ template <> mlir::APFloat getZeroInitFromType(mlir::Type Ty);
3435
mlir::Type getNestedTypeAndElemQuantity(mlir::Type Ty, unsigned &elemQuantity);
3536

3637
template <typename AttrTy, typename StorageTy>
37-
void convertToDenseElementsAttrImpl(cir::ConstArrayAttr attr,
38-
llvm::SmallVectorImpl<StorageTy> &values);
38+
void convertToDenseElementsAttrImpl(
39+
cir::ConstArrayAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
40+
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
41+
int64_t currentIndex);
3942

4043
template <typename AttrTy, typename StorageTy>
41-
mlir::DenseElementsAttr
42-
convertToDenseElementsAttr(cir::ConstArrayAttr attr,
43-
const llvm::SmallVectorImpl<int64_t> &dims,
44-
mlir::Type type);
44+
void convertToDenseElementsAttrImpl(
45+
cir::ConstVectorAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
46+
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
47+
int64_t currentIndex);
4548

4649
template <typename AttrTy, typename StorageTy>
47-
mlir::DenseElementsAttr
48-
convertToDenseElementsAttr(cir::ConstVectorAttr attr,
49-
const llvm::SmallVectorImpl<int64_t> &dims,
50-
mlir::Type type);
50+
void convertToDenseElementsAttrImpl(
51+
cir::ComplexAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
52+
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
53+
int64_t currentIndex);
54+
55+
template <typename AttrTy, typename StorageTy>
56+
mlir::DenseElementsAttr convertToDenseElementsAttr(
57+
cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
58+
mlir::Type elementType, mlir::Type convertedElementType);
59+
60+
template <typename AttrTy, typename StorageTy>
61+
mlir::DenseElementsAttr convertToDenseElementsAttr(
62+
cir::ConstVectorAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
63+
mlir::Type elementType, mlir::Type convertedElementType);
64+
65+
template <typename AttrTy, typename StorageTy>
66+
mlir::DenseElementsAttr convertToDenseElementsAttr(
67+
cir::ComplexAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
68+
mlir::Type elementType, mlir::Type convertedElementType);
5169

5270
std::optional<mlir::Attribute>
5371
lowerConstArrayAttr(cir::ConstArrayAttr constArr,

clang/lib/CIR/CodeGen/CIRGenCXX.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ static void emitDeclInit(CIRGenFunction &cgf, const VarDecl *varDecl,
204204
cgf.emitScalarInit(init, cgf.getLoc(varDecl->getLocation()), lv, false);
205205
return;
206206
case cir::TEK_Complex:
207-
llvm_unreachable("complext evaluation NYI");
207+
cgf.emitComplexExprIntoLValue(init, lv, /*isInit=*/true);
208+
return;
208209
case cir::TEK_Aggregate:
209210
cgf.emitAggExpr(init,
210211
AggValueSlot::forLValue(lv, AggValueSlot::IsDestructed,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2738,7 +2738,22 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstComplex(
27382738
auto constComplex = mlir::cast<cir::ComplexAttr>(init);
27392739
if (auto val = lowerConstComplexAttr(constComplex, getTypeConverter());
27402740
val.has_value()) {
2741-
init = val.value();
2741+
auto loweredAttr = val.value();
2742+
if (auto dense = mlir::dyn_cast<mlir::DenseElementsAttr>(loweredAttr)) {
2743+
llvm::SmallVector<mlir::Attribute, 2> components;
2744+
components.reserve(dense.getNumElements());
2745+
auto elementType = dense.getElementType();
2746+
if (mlir::isa<mlir::IntegerType>(elementType)) {
2747+
for (auto value : dense.getValues<mlir::APInt>())
2748+
components.push_back(mlir::IntegerAttr::get(elementType, value));
2749+
} else if (mlir::isa<mlir::FloatType>(elementType)) {
2750+
for (auto value : dense.getValues<mlir::APFloat>())
2751+
components.push_back(mlir::FloatAttr::get(elementType, value));
2752+
}
2753+
if (!components.empty())
2754+
loweredAttr = mlir::ArrayAttr::get(rewriter.getContext(), components);
2755+
}
2756+
init = loweredAttr;
27422757
useInitializerRegion = false;
27432758
} else
27442759
useInitializerRegion = true;

clang/lib/CIR/Lowering/LoweringHelpers.cpp

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ void convertToDenseElementsAttrImpl(
145145
elementsSizeInCurrentDim *= currentDims[i];
146146

147147
auto attrArray =
148-
mlir::ArrayAttr::get(attr.getContext(), {attr.getImag(), attr.getReal()});
148+
mlir::ArrayAttr::get(attr.getContext(), {attr.getReal(), attr.getImag()});
149149
for (auto eltAttr : attrArray) {
150150
if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
151151
values[currentIndex++] = valueAttr.getValue();
@@ -206,9 +206,13 @@ mlir::DenseElementsAttr convertToDenseElementsAttr(
206206
array_size, getZeroInitFromType<StorageTy>(elementType));
207207
convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0,
208208
/*initialIndex=*/0);
209-
return mlir::DenseElementsAttr::get(
210-
mlir::RankedTensorType::get(dims, convertedElementType),
211-
llvm::ArrayRef(values));
209+
mlir::ShapedType shapedType;
210+
if (dims.size() == 1) {
211+
shapedType = mlir::VectorType::get(dims, convertedElementType);
212+
} else {
213+
shapedType = mlir::RankedTensorType::get(dims, convertedElementType);
214+
}
215+
return mlir::DenseElementsAttr::get(shapedType, llvm::ArrayRef(values));
212216
}
213217

214218
std::optional<mlir::Attribute>
@@ -259,28 +263,20 @@ lowerConstComplexAttr(cir::ComplexAttr constComplex,
259263
if (!convertedElementType)
260264
return std::nullopt;
261265

262-
llvm::SmallVector<mlir::Attribute, 2> components;
263-
components.reserve(2);
264-
265-
if (auto cirIntTy = mlir::dyn_cast<cir::IntType>(elementType)) {
266-
(void)cirIntTy;
267-
auto real = mlir::cast<cir::IntAttr>(constComplex.getReal());
268-
auto imag = mlir::cast<cir::IntAttr>(constComplex.getImag());
269-
components.push_back(
270-
mlir::IntegerAttr::get(convertedElementType, real.getValue()));
271-
components.push_back(
272-
mlir::IntegerAttr::get(convertedElementType, imag.getValue()));
273-
return mlir::ArrayAttr::get(constComplex.getContext(), components);
266+
llvm::SmallVector<int64_t, 1> dims{2};
267+
268+
if (mlir::isa<cir::IntType>(elementType)) {
269+
auto dense =
270+
convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
271+
constComplex, dims, elementType, convertedElementType);
272+
return std::optional<mlir::Attribute>(dense);
274273
}
275274

276275
if (mlir::isa<cir::FPTypeInterface>(elementType)) {
277-
auto real = mlir::cast<cir::FPAttr>(constComplex.getReal());
278-
auto imag = mlir::cast<cir::FPAttr>(constComplex.getImag());
279-
components.push_back(
280-
mlir::FloatAttr::get(convertedElementType, real.getValue()));
281-
components.push_back(
282-
mlir::FloatAttr::get(convertedElementType, imag.getValue()));
283-
return mlir::ArrayAttr::get(constComplex.getContext(), components);
276+
auto dense =
277+
convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
278+
constComplex, dims, elementType, convertedElementType);
279+
return std::optional<mlir::Attribute>(dense);
284280
}
285281

286282
return std::nullopt;

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,16 @@ class CIRConstantOpLowering
697697
}
698698
return mlir::DenseElementsAttr::get(
699699
mlir::cast<mlir::ShapedType>(mlirType), mlirValues);
700+
} else if (auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(cirAttr)) {
701+
(void)zeroAttr;
702+
return rewriter.getZeroAttr(mlirType);
703+
} else if (auto complexAttr = mlir::dyn_cast<cir::ComplexAttr>(cirAttr)) {
704+
auto vecType = mlir::dyn_cast<mlir::VectorType>(mlirType);
705+
assert(vecType && "complex attribute lowered type should be a vector");
706+
SmallVector<mlir::Attribute, 2> elements{
707+
this->lowerCirAttrToMlirAttr(complexAttr.getReal(), rewriter),
708+
this->lowerCirAttrToMlirAttr(complexAttr.getImag(), rewriter)};
709+
return mlir::DenseElementsAttr::get(vecType, elements);
700710
} else if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(cirAttr)) {
701711
return rewriter.getIntegerAttr(mlirType, boolAttr.getValue());
702712
} else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(cirAttr)) {
@@ -1133,18 +1143,30 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
11331143
initialValue = init.value();
11341144
else
11351145
llvm_unreachable("GlobalOp lowering array with initial value fail");
1136-
} else if (auto constArr = mlir::dyn_cast<cir::ZeroAttr>(init.value())) {
1146+
} else if (auto constComplex =
1147+
mlir::dyn_cast<cir::ComplexAttr>(init.value())) {
1148+
if (auto lowered =
1149+
cir::direct::lowerConstComplexAttr(constComplex,
1150+
getTypeConverter());
1151+
lowered.has_value())
1152+
initialValue = lowered.value();
1153+
else
1154+
llvm_unreachable(
1155+
"GlobalOp lowering complex with initial value failed");
1156+
} else if (auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(init.value())) {
1157+
(void)zeroAttr;
11371158
if (memrefType.getShape().size()) {
11381159
auto elementType = memrefType.getElementType();
11391160
auto rtt =
11401161
mlir::RankedTensorType::get(memrefType.getShape(), elementType);
11411162
if (mlir::isa<mlir::IntegerType>(elementType))
11421163
initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
11431164
else if (mlir::isa<mlir::FloatType>(elementType)) {
1144-
auto floatZero = mlir::FloatAttr::get(elementType, 0.0).getValue();
1165+
auto floatZero =
1166+
mlir::FloatAttr::get(elementType, 0.0).getValue();
11451167
initialValue = mlir::DenseFPElementsAttr::get(rtt, floatZero);
11461168
} else
1147-
llvm_unreachable("GlobalOp lowering unsuppored element type");
1169+
initialValue = mlir::Attribute();
11481170
} else {
11491171
auto rtt = mlir::RankedTensorType::get({}, convertedType);
11501172
if (mlir::isa<mlir::IntegerType>(convertedType))
@@ -1154,7 +1176,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
11541176
mlir::FloatAttr::get(convertedType, 0.0).getValue();
11551177
initialValue = mlir::DenseFPElementsAttr::get(rtt, floatZero);
11561178
} else
1157-
llvm_unreachable("GlobalOp lowering unsuppored type");
1179+
initialValue = mlir::Attribute();
11581180
}
11591181
} else if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(init.value())) {
11601182
auto rtt = mlir::RankedTensorType::get({}, convertedType);
@@ -1207,6 +1229,67 @@ class CIRGetGlobalOpLowering
12071229
}
12081230
};
12091231

1232+
class CIRComplexCreateOpLowering
1233+
: public mlir::OpConversionPattern<cir::ComplexCreateOp> {
1234+
public:
1235+
using OpConversionPattern<cir::ComplexCreateOp>::OpConversionPattern;
1236+
1237+
mlir::LogicalResult
1238+
matchAndRewrite(cir::ComplexCreateOp op, OpAdaptor adaptor,
1239+
mlir::ConversionPatternRewriter &rewriter) const override {
1240+
auto loc = op.getLoc();
1241+
auto vecType =
1242+
mlir::cast<mlir::VectorType>(getTypeConverter()->convertType(
1243+
op.getType()));
1244+
auto zeroAttr = rewriter.getZeroAttr(vecType);
1245+
mlir::Value result =
1246+
rewriter.create<mlir::arith::ConstantOp>(loc, vecType, zeroAttr)
1247+
.getResult();
1248+
SmallVector<int64_t, 1> realIdx{0};
1249+
SmallVector<int64_t, 1> imagIdx{1};
1250+
result = rewriter
1251+
.create<mlir::vector::InsertOp>(loc, adaptor.getReal(), result,
1252+
realIdx)
1253+
.getResult();
1254+
result = rewriter
1255+
.create<mlir::vector::InsertOp>(loc, adaptor.getImag(), result,
1256+
imagIdx)
1257+
.getResult();
1258+
rewriter.replaceOp(op, result);
1259+
return mlir::success();
1260+
}
1261+
};
1262+
1263+
class CIRComplexRealOpLowering
1264+
: public mlir::OpConversionPattern<cir::ComplexRealOp> {
1265+
public:
1266+
using OpConversionPattern<cir::ComplexRealOp>::OpConversionPattern;
1267+
1268+
mlir::LogicalResult
1269+
matchAndRewrite(cir::ComplexRealOp op, OpAdaptor adaptor,
1270+
mlir::ConversionPatternRewriter &rewriter) const override {
1271+
SmallVector<int64_t, 1> idx{0};
1272+
rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
1273+
op, adaptor.getOperand(), idx);
1274+
return mlir::success();
1275+
}
1276+
};
1277+
1278+
class CIRComplexImagOpLowering
1279+
: public mlir::OpConversionPattern<cir::ComplexImagOp> {
1280+
public:
1281+
using OpConversionPattern<cir::ComplexImagOp>::OpConversionPattern;
1282+
1283+
mlir::LogicalResult
1284+
matchAndRewrite(cir::ComplexImagOp op, OpAdaptor adaptor,
1285+
mlir::ConversionPatternRewriter &rewriter) const override {
1286+
SmallVector<int64_t, 1> idx{1};
1287+
rewriter.replaceOpWithNewOp<mlir::vector::ExtractOp>(
1288+
op, adaptor.getOperand(), idx);
1289+
return mlir::success();
1290+
}
1291+
};
1292+
12101293
class CIRVectorCreateLowering
12111294
: public mlir::OpConversionPattern<cir::VecCreateOp> {
12121295
public:
@@ -1601,12 +1684,13 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
16011684
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
16021685
CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
16031686
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1604-
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1605-
CIRGetElementOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1606-
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1607-
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1608-
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1609-
CIRSinOpLowering, CIRTanOpLowering, CIRShiftOpLowering,
1687+
CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
1688+
CIRComplexRealOpLowering, CIRComplexImagOpLowering, CIRCastOpLowering,
1689+
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1690+
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1691+
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1692+
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1693+
CIRRoundOpLowering, CIRSinOpLowering, CIRTanOpLowering, CIRShiftOpLowering,
16101694
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
16111695
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
16121696
CIRIfOpLowering, CIRScopeOpLowering, CIRVectorCreateLowering,
@@ -1679,6 +1763,12 @@ static mlir::TypeConverter prepareTypeConverter() {
16791763
auto ty = converter.convertType(type.getElementType());
16801764
return mlir::VectorType::get(type.getSize(), ty);
16811765
});
1766+
converter.addConversion([&](cir::ComplexType type) -> mlir::Type {
1767+
auto elemTy = converter.convertType(type.getElementType());
1768+
if (!elemTy)
1769+
return nullptr;
1770+
return mlir::VectorType::get(2, elemTy);
1771+
});
16821772
return converter;
16831773
}
16841774

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-lowering-prepare %s -o %t.cir 2> %t-before.cir
2+
// RUN: FileCheck --input-file=%t-before.cir %s --check-prefix=CIR-BEFORE-LPP
3+
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
4+
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
5+
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
6+
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
7+
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
8+
9+
float num;
10+
float _Complex a = {num, num};
11+
12+
// CIR-BEFORE-LPP: cir.global external @num = #cir.fp<0.000000e+00> : !cir.float
13+
// CIR-BEFORE-LPP: cir.global external @a = ctor : !cir.complex<!cir.float> {
14+
// CIR-BEFORE-LPP: %[[THIS:.*]] = cir.get_global @a : !cir.ptr<!cir.complex<!cir.float>>
15+
// CIR-BEFORE-LPP: %[[NUM:.*]] = cir.get_global @num : !cir.ptr<!cir.float>
16+
// CIR-BEFORE-LPP: %[[REAL:.*]] = cir.load{{.*}} %[[NUM]] : !cir.ptr<!cir.float>, !cir.float
17+
// CIR-BEFORE-LPP: %[[NUM:.*]] = cir.get_global @num : !cir.ptr<!cir.float>
18+
// CIR-BEFORE-LPP: %[[IMAG:.*]] = cir.load{{.*}} %[[NUM]] : !cir.ptr<!cir.float>, !cir.float
19+
// CIR-BEFORE-LPP: %[[COMPLEX_VAL:.*]] = cir.complex.create %[[REAL]], %[[IMAG]] : !cir.float -> !cir.complex<!cir.float>
20+
// CIR-BEFORE-LPP: cir.store{{.*}} %[[COMPLEX_VAL:.*]], %[[THIS]] : !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>
21+
// CIR-BEFORE-LPP: }
22+
23+
// CIR: cir.global external @num = #cir.fp<0.000000e+00> : !cir.float
24+
// CIR: cir.global external @a = #cir.zero : !cir.complex<!cir.float>
25+
// CIR: cir.func internal private @__cxx_global_var_init()
26+
// CIR: %[[A_ADDR:.*]] = cir.get_global @a : !cir.ptr<!cir.complex<!cir.float>>
27+
// CIR: %[[NUM:.*]] = cir.get_global @num : !cir.ptr<!cir.float>
28+
// CIR: %[[REAL:.*]] = cir.load{{.*}} %[[NUM]] : !cir.ptr<!cir.float>, !cir.float
29+
// CIR: %[[NUM:.*]] = cir.get_global @num : !cir.ptr<!cir.float>
30+
// CIR: %[[IMAG:.*]] = cir.load{{.*}} %[[NUM]] : !cir.ptr<!cir.float>, !cir.float
31+
// CIR: %[[COMPLEX_VAL:.*]] = cir.complex.create %[[REAL]], %[[IMAG]] : !cir.float -> !cir.complex<!cir.float>
32+
// CIR: cir.store{{.*}} %[[COMPLEX_VAL]], %[[A_ADDR]] : !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>
33+
34+
// LLVM: define internal void @__cxx_global_var_init()
35+
// LLVM: %[[REAL:.*]] = load float, ptr @num, align 4
36+
// LLVM: %[[IMAG:.*]] = load float, ptr @num, align 4
37+
// LLVM: %[[TMP_COMPLEX_VAL:.*]] = insertvalue { float, float } {{.*}}, float %[[REAL]], 0
38+
// LLVM: %[[COMPLEX_VAL:.*]] = insertvalue { float, float } %[[TMP_COMPLEX_VAL]], float %[[IMAG]], 1
39+
// LLVM: store { float, float } %[[COMPLEX_VAL]], ptr @a, align 4
40+
41+
// OGCG: define internal void @__cxx_global_var_init() {{.*}} section ".text.startup"
42+
// OGCG: %[[REAL:.*]] = load float, ptr @num, align 4
43+
// OGCG: %[[IMAG:.*]] = load float, ptr @num, align 4
44+
// OGCG: store float %[[REAL]], ptr @a, align 4
45+
// OGCG: store float %[[IMAG]], ptr getelementptr inbounds nuw ({ float, float }, ptr @a, i32 0, i32 1), align 4

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,22 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
570570
return llvm::Constant::getNullValue(llvmType);
571571
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
572572
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
573+
if (!arrayAttr) {
574+
if (auto dense = dyn_cast<DenseElementsAttr>(attr)) {
575+
SmallVector<Attribute> elementAttrs;
576+
elementAttrs.reserve(dense.getNumElements());
577+
Type elementType = dense.getElementType();
578+
if (isa<IntegerType>(elementType)) {
579+
for (const APInt &value : dense.getValues<APInt>())
580+
elementAttrs.push_back(IntegerAttr::get(elementType, value));
581+
} else if (isa<FloatType>(elementType)) {
582+
for (const APFloat &value : dense.getValues<APFloat>())
583+
elementAttrs.push_back(FloatAttr::get(elementType, value));
584+
}
585+
if (!elementAttrs.empty())
586+
arrayAttr = ArrayAttr::get(attr.getContext(), elementAttrs);
587+
}
588+
}
573589
if (!arrayAttr) {
574590
emitError(loc, "expected an array attribute for a struct constant");
575591
return nullptr;

0 commit comments

Comments
 (0)