Skip to content

Commit 3a8bcd0

Browse files
committed
[CIR] Upstream global initialization for ComplexType
1 parent a9b2998 commit 3a8bcd0

File tree

11 files changed

+257
-18
lines changed

11 files changed

+257
-18
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
8989
return cir::IntAttr::get(ty, 0);
9090
if (cir::isAnyFloatingPointType(ty))
9191
return cir::FPAttr::getZero(ty);
92+
if (auto complexType = mlir::dyn_cast<cir::ComplexType>(ty))
93+
return cir::ZeroAttr::get(complexType);
9294
if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty))
9395
return cir::ZeroAttr::get(arrTy);
9496
if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty))

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,4 +276,38 @@ def ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> {
276276
}];
277277
}
278278

279+
//===----------------------------------------------------------------------===//
280+
// ConstComplexAttr
281+
//===----------------------------------------------------------------------===//
282+
283+
def ConstComplexAttr : CIR_Attr<"ConstComplex", "const_complex", [TypedAttrInterface]> {
284+
let summary = "An attribute that contains a constant complex value";
285+
let description = [{
286+
The `#cir.const_complex` attribute contains a constant value of complex number
287+
type. The `real` parameter gives the real part of the complex number and the
288+
`imag` parameter gives the imaginary part of the complex number.
289+
290+
The `real` and `imag` parameter must be either an IntAttr or an FPAttr that
291+
contains values of the same CIR type.
292+
}];
293+
294+
let parameters = (ins
295+
AttributeSelfTypeParameter<"", "cir::ComplexType">:$type,
296+
"mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);
297+
298+
let builders = [
299+
AttrBuilderWithInferredContext<(ins "cir::ComplexType":$type,
300+
"mlir::TypedAttr":$real,
301+
"mlir::TypedAttr":$imag), [{
302+
return $_get(type.getContext(), type, real, imag);
303+
}]>,
304+
];
305+
306+
let genVerifyDecl = 1;
307+
308+
let assemblyFormat = [{
309+
`<` qualified($real) `,` qualified($imag) `>`
310+
}];
311+
}
312+
279313
#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,49 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
161161
}];
162162
}
163163

164+
//===----------------------------------------------------------------------===//
165+
// ComplexType
166+
//===----------------------------------------------------------------------===//
167+
168+
def CIR_ComplexType : CIR_Type<"Complex", "complex",
169+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
170+
171+
let summary = "CIR complex type";
172+
let description = [{
173+
CIR type that represents a C complex number. `cir.complex` models the C type
174+
`T _Complex`.
175+
176+
The type models complex values, per C99 6.2.5p11. It supports the C99
177+
complex float types as well as the GCC integer complex extensions.
178+
179+
The parameter `elementType` gives the type of the real and imaginary part of
180+
the complex number. `elementType` must be either a CIR integer type or a CIR
181+
floating-point type.
182+
}];
183+
184+
let parameters = (ins CIR_AnyIntOrFloatType:$elementType);
185+
186+
let builders = [
187+
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
188+
return $_get(elementType.getContext(), elementType);
189+
}]>,
190+
];
191+
192+
let assemblyFormat = [{
193+
`<` $elementType `>`
194+
}];
195+
196+
let extraClassDeclaration = [{
197+
bool isFloatingPointComplex() const {
198+
return isAnyFloatingPointType(getElementType());
199+
}
200+
201+
bool isIntegerComplex() const {
202+
return mlir::isa<cir::IntType>(getElementType());
203+
}
204+
}];
205+
}
206+
164207
//===----------------------------------------------------------------------===//
165208
// PointerType
166209
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -577,12 +577,31 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
577577
case APValue::Union:
578578
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate struct or union");
579579
return {};
580-
case APValue::FixedPoint:
581580
case APValue::ComplexInt:
582-
case APValue::ComplexFloat:
581+
case APValue::ComplexFloat: {
582+
mlir::Type desiredType = cgm.convertType(destType);
583+
cir::ComplexType complexType =
584+
mlir::dyn_cast<cir::ComplexType>(desiredType);
585+
586+
mlir::Type compelxElemTy = complexType.getElementType();
587+
if (isa<cir::IntType>(compelxElemTy)) {
588+
llvm::APSInt real = value.getComplexIntReal();
589+
llvm::APSInt imag = value.getComplexIntImag();
590+
return builder.getAttr<cir::ConstComplexAttr>(
591+
complexType, builder.getAttr<cir::IntAttr>(compelxElemTy, real),
592+
builder.getAttr<cir::IntAttr>(compelxElemTy, imag));
593+
}
594+
595+
llvm::APFloat real = value.getComplexFloatReal();
596+
llvm::APFloat imag = value.getComplexFloatImag();
597+
return builder.getAttr<cir::ConstComplexAttr>(
598+
complexType, builder.getAttr<cir::FPAttr>(compelxElemTy, real),
599+
builder.getAttr<cir::FPAttr>(compelxElemTy, imag));
600+
}
601+
case APValue::FixedPoint:
583602
case APValue::AddrLabelDiff:
584-
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate fixed point, complex int, "
585-
"complex float, addr label diff");
603+
cgm.errorNYI(
604+
"ConstExprEmitter::tryEmitPrivate fixed point, addr label diff");
586605
return {};
587606
}
588607
llvm_unreachable("Unknown APValue kind");

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,13 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
385385
break;
386386
}
387387

388+
case Type::Complex: {
389+
const ComplexType *ct = cast<ComplexType>(ty);
390+
mlir::Type elementTy = convertType(ct->getElementType());
391+
resultType = cir::ComplexType::get(elementTy);
392+
break;
393+
}
394+
388395
case Type::LValueReference:
389396
case Type::RValueReference: {
390397
const ReferenceType *refTy = cast<ReferenceType>(ty);

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,26 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
184184
return success();
185185
}
186186

187+
//===----------------------------------------------------------------------===//
188+
// ConstComplexAttr definitions
189+
//===----------------------------------------------------------------------===//
190+
191+
LogicalResult
192+
ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
193+
cir::ComplexType type, mlir::TypedAttr real,
194+
mlir::TypedAttr imag) {
195+
mlir::Type elemType = type.getElementType();
196+
if (real.getType() != elemType)
197+
return emitError()
198+
<< "type of the real part does not match the complex type";
199+
200+
if (imag.getType() != elemType)
201+
return emitError()
202+
<< "type of the imaginary part does not match the complex type";
203+
204+
return success();
205+
}
206+
187207
//===----------------------------------------------------------------------===//
188208
// CIR ConstArrayAttr
189209
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
231231
}
232232

233233
if (isa<cir::ZeroAttr>(attrType)) {
234-
if (isa<cir::RecordType, cir::ArrayType, cir::VectorType>(opType))
234+
if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
235+
opType))
235236
return success();
236237
return op->emitOpError("zero expects struct or array type");
237238
}
@@ -253,7 +254,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
253254
return success();
254255
}
255256

256-
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
257+
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
258+
cir::ConstComplexAttr>(attrType))
257259
return success();
258260

259261
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,32 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
552552
.getABIAlignment(dataLayout, params);
553553
}
554554

555+
//===----------------------------------------------------------------------===//
556+
// ComplexType Definitions
557+
//===----------------------------------------------------------------------===//
558+
559+
llvm::TypeSize
560+
cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
561+
mlir::DataLayoutEntryListRef params) const {
562+
// C17 6.2.5p13:
563+
// Each complex type has the same representation and alignment requirements
564+
// as an array type containing exactly two elements of the corresponding
565+
// real type.
566+
567+
return dataLayout.getTypeSizeInBits(getElementType()) * 2;
568+
}
569+
570+
uint64_t
571+
cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout,
572+
mlir::DataLayoutEntryListRef params) const {
573+
// C17 6.2.5p13:
574+
// Each complex type has the same representation and alignment requirements
575+
// as an array type containing exactly two elements of the corresponding
576+
// real type.
577+
578+
return dataLayout.getTypeABIAlignment(getElementType());
579+
}
580+
555581
//===----------------------------------------------------------------------===//
556582
// Floating-point and Float-point Vector type helpers
557583
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,15 @@ class CIRAttrToValue {
188188

189189
mlir::Value visit(mlir::Attribute attr) {
190190
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
191-
.Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
192-
cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
193-
[&](auto attrT) { return visitCirAttr(attrT); })
191+
.Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
192+
cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
193+
cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
194194
.Default([&](auto attrT) { return mlir::Value(); });
195195
}
196196

197197
mlir::Value visitCirAttr(cir::IntAttr intAttr);
198198
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
199+
mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
199200
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
200201
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
201202
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
@@ -226,6 +227,42 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
226227
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
227228
}
228229

230+
/// FPAttr visitor.
231+
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
232+
mlir::Location loc = parentOp->getLoc();
233+
return rewriter.create<mlir::LLVM::ConstantOp>(
234+
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
235+
}
236+
237+
/// ConstComplexAttr visitor.
238+
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) {
239+
auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
240+
auto complexElemTy = complexType.getElementType();
241+
auto complexElemLLVMTy = converter->convertType(complexElemTy);
242+
243+
mlir::Attribute components[2];
244+
if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
245+
components[0] = rewriter.getIntegerAttr(
246+
complexElemLLVMTy,
247+
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
248+
components[1] = rewriter.getIntegerAttr(
249+
complexElemLLVMTy,
250+
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
251+
} else {
252+
components[0] = rewriter.getFloatAttr(
253+
complexElemLLVMTy,
254+
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
255+
components[1] = rewriter.getFloatAttr(
256+
complexElemLLVMTy,
257+
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
258+
}
259+
260+
mlir::Location loc = parentOp->getLoc();
261+
return rewriter.create<mlir::LLVM::ConstantOp>(
262+
loc, converter->convertType(complexAttr.getType()),
263+
rewriter.getArrayAttr(components));
264+
}
265+
229266
/// ConstPtrAttr visitor.
230267
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
231268
mlir::Location loc = parentOp->getLoc();
@@ -241,13 +278,6 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
241278
loc, converter->convertType(ptrAttr.getType()), ptrVal);
242279
}
243280

244-
/// FPAttr visitor.
245-
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
246-
mlir::Location loc = parentOp->getLoc();
247-
return rewriter.create<mlir::LLVM::ConstantOp>(
248-
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
249-
}
250-
251281
// ConstArrayAttr visitor
252282
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
253283
mlir::Type llvmTy = converter->convertType(attr.getType());
@@ -341,9 +371,11 @@ class GlobalInitAttrRewriter {
341371
mlir::Attribute visitCirAttr(cir::IntAttr attr) {
342372
return rewriter.getIntegerAttr(llvmType, attr.getValue());
343373
}
374+
344375
mlir::Attribute visitCirAttr(cir::FPAttr attr) {
345376
return rewriter.getFloatAttr(llvmType, attr.getValue());
346377
}
378+
347379
mlir::Attribute visitCirAttr(cir::BoolAttr attr) {
348380
return rewriter.getBoolAttr(attr.getValue());
349381
}
@@ -990,7 +1022,7 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
9901022
mlir::ConversionPatternRewriter &rewriter) const {
9911023
// TODO: Generalize this handling when more types are needed here.
9921024
assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
993-
cir::ZeroAttr>(init)));
1025+
cir::ConstComplexAttr, cir::ZeroAttr>(init)));
9941026

9951027
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
9961028
// should be updated. For now, we use a custom op to initialize globals
@@ -1043,7 +1075,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
10431075
return mlir::failure();
10441076
}
10451077
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
1046-
cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
1078+
cir::ConstPtrAttr, cir::ConstComplexAttr,
1079+
cir::ZeroAttr>(init.value())) {
10471080
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
10481081
// should be updated. For now, we use a custom op to initialize globals
10491082
// to the appropriate value.
@@ -1559,6 +1592,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
15591592
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
15601593
return mlir::BFloat16Type::get(type.getContext());
15611594
});
1595+
converter.addConversion([&](cir::ComplexType type) -> mlir::Type {
1596+
// A complex type is lowered to an LLVM struct that contains the real and
1597+
// imaginary part as data fields.
1598+
mlir::Type elementTy = converter.convertType(type.getElementType());
1599+
mlir::Type structFields[2] = {elementTy, elementTy};
1600+
return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
1601+
structFields);
1602+
});
15621603
converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> {
15631604
auto result = converter.convertType(type.getReturnType());
15641605
llvm::SmallVector<mlir::Type> arguments;

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
7+
8+
int _Complex ci;
9+
10+
float _Complex cf;
11+
12+
int _Complex ci2 = { 1, 2 };
13+
14+
float _Complex cf2 = { 1.0f, 2.0f };
15+
16+
// CIR: cir.global external @ci = #cir.zero : !cir.complex<!s32i>
17+
// CIR: cir.global external @cf = #cir.zero : !cir.complex<!cir.float>
18+
// CIR: cir.global external @ci2 = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
19+
// CIR: cir.global external @cf2 = #cir.const_complex<#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00> : !cir.float> : !cir.complex<!cir.float>
20+
21+
// LLVM: {{.*}} = dso_local global { i32, i32 } zeroinitializer, align 4
22+
// LLVM: {{.*}} = dso_local global { float, float } zeroinitializer, align 4
23+
// LLVM: {{.*}} = dso_local global { i32, i32 } { i32 1, i32 2 }, align 4
24+
// LLVM: {{.*}} = dso_local global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4
25+
26+
// OGCG: {{.*}} = global { i32, i32 } zeroinitializer, align 4
27+
// OGCG: {{.*}} = global { float, float } zeroinitializer, align 4
28+
// OGCG: {{.*}} = global { i32, i32 } { i32 1, i32 2 }, align 4
29+
// OGCG: {{.*}} = global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4

0 commit comments

Comments
 (0)