Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,44 @@ def CIR_ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector", [
let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// ConstRecordAttr
//===----------------------------------------------------------------------===//

def CIR_ConstRecordAttr : CIR_Attr<"ConstRecord", "const_record", [
TypedAttrInterface
]> {
let summary = "Represents a constant record";
let description = [{
Effectively supports "struct-like" constants. It's must be built from
an `mlir::ArrayAttr` instance where each element is a typed attribute
(`mlir::TypedAttribute`).

Example:
```
cir.global external @rgb2 = #cir.const_record<{0 : i8,
5 : i64, #cir.null : !cir.ptr<i8>
}> : !cir.record<"", i8, i64, !cir.ptr<i8>>
```
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type,
"mlir::ArrayAttr":$members);

let builders = [
AttrBuilderWithInferredContext<(ins "cir::RecordType":$type,
"mlir::ArrayAttr":$members), [{
return $_get(type.getContext(), type, members);
}]>
];

let assemblyFormat = [{
`<` custom<RecordMembers>($members) `>`
}];

let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// ConstPtrAttr
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
trailingZerosNum);
}

cir::ConstRecordAttr getAnonConstRecord(mlir::ArrayAttr arrayAttr,
bool packed = false,
bool padded = false,
mlir::Type ty = {}) {
llvm::SmallVector<mlir::Type, 4> members;
for (auto &f : arrayAttr) {
auto ta = mlir::cast<mlir::TypedAttr>(f);
members.push_back(ta.getType());
}

if (!ty)
ty = getAnonRecordTy(members, packed, padded);

auto sTy = mlir::cast<cir::RecordType>(ty);
return cir::ConstRecordAttr::get(sTy, arrayAttr);
}

std::string getUniqueAnonRecordName() { return getUniqueRecordName("anon"); }

std::string getUniqueRecordName(const std::string &baseName) {
Expand Down
38 changes: 35 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
mlir::Type commonElementType, unsigned arrayBound,
SmallVectorImpl<mlir::TypedAttr> &elements,
mlir::TypedAttr filler) {
const CIRGenBuilderTy &builder = cgm.getBuilder();
CIRGenBuilderTy &builder = cgm.getBuilder();

unsigned nonzeroLength = arrayBound;
if (elements.size() < nonzeroLength && builder.isNullValue(filler))
Expand All @@ -306,6 +306,33 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
if (trailingZeroes >= 8) {
assert(elements.size() >= nonzeroLength &&
"missing initializer for non-zero element");

if (commonElementType && nonzeroLength >= 8) {
// If all the elements had the same type up to the trailing zeroes and
// there are eight or more nonzero elements, emit a struct of two arrays
// (the nonzero data and the zeroinitializer).
SmallVector<mlir::Attribute, 4> eles;
eles.reserve(nonzeroLength);
for (const auto &element : elements)
eles.push_back(element);
auto initial = cir::ConstArrayAttr::get(
cir::ArrayType::get(commonElementType, nonzeroLength),
mlir::ArrayAttr::get(builder.getContext(), eles));
elements.resize(2);
elements[0] = initial;
} else {
// Otherwise, emit a struct with individual elements for each nonzero
// initializer, followed by a zeroinitializer array filler.
elements.resize(nonzeroLength + 1);
}

mlir::Type fillerType =
commonElementType
? commonElementType
: mlir::cast<cir::ArrayType>(desiredType).getElementType();
fillerType = cir::ArrayType::get(fillerType, trailingZeroes);
elements.back() = cir::ZeroAttr::get(fillerType);
commonElementType = nullptr;
} else if (elements.size() != arrayBound) {
elements.resize(arrayBound, filler);

Expand All @@ -325,8 +352,13 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
mlir::ArrayAttr::get(builder.getContext(), eles));
}

cgm.errorNYI("array with different type elements");
return {};
SmallVector<mlir::Attribute, 4> eles;
eles.reserve(elements.size());
for (auto const &element : elements)
eles.push_back(element);

auto arrAttr = mlir::ArrayAttr::get(builder.getContext(), eles);
return builder.getAnonConstRecord(arrAttr, /*isPacked=*/true);
}

//===----------------------------------------------------------------------===//
Expand Down
63 changes: 63 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

//===-----------------------------------------------------------------===//
// RecordMembers
//===-----------------------------------------------------------------===//

static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser,
mlir::ArrayAttr &members);

//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//
Expand Down Expand Up @@ -68,6 +76,61 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
llvm_unreachable("unexpected CIR type kind");
}

static void printRecordMembers(mlir::AsmPrinter &printer,
mlir::ArrayAttr members) {
printer << '{';
llvm::interleaveComma(members, printer);
printer << '}';
}

static ParseResult parseRecordMembers(mlir::AsmParser &parser,
mlir::ArrayAttr &members) {
llvm::SmallVector<mlir::Attribute, 4> elts;

auto delimiter = AsmParser::Delimiter::Braces;
auto result = parser.parseCommaSeparatedList(delimiter, [&]() {
mlir::TypedAttr attr;
if (parser.parseAttribute(attr).failed())
return mlir::failure();
elts.push_back(attr);
return mlir::success();
});

if (result.failed())
return mlir::failure();

members = mlir::ArrayAttr::get(parser.getContext(), elts);
return mlir::success();
}

//===----------------------------------------------------------------------===//
// ConstRecordAttr definitions
//===----------------------------------------------------------------------===//

LogicalResult
ConstRecordAttr::verify(function_ref<InFlightDiagnostic()> emitError,
mlir::Type type, ArrayAttr members) {
auto sTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
if (!sTy)
return emitError() << "expected !cir.record type";

if (sTy.getMembers().size() != members.size())
return emitError() << "number of elements must match";

unsigned attrIdx = 0;
for (auto &member : sTy.getMembers()) {
auto m = mlir::cast<mlir::TypedAttr>(members[attrIdx]);
if (member != m.getType())
return emitError() << "element at index " << attrIdx << " has type "
<< m.getType()
<< " but the expected type for this element is "
<< member;
attrIdx++;
}

return success();
}

//===----------------------------------------------------------------------===//
// OptInfoAttr definitions
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}

if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstComplexAttr, cir::GlobalViewAttr, cir::PoisonAttr>(
attrType))
cir::ConstComplexAttr, cir::ConstRecordAttr,
cir::GlobalViewAttr, cir::PoisonAttr>(attrType))
return success();

assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
Expand Down
36 changes: 29 additions & 7 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ class CIRAttrToValue {
mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
.Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::GlobalViewAttr, cir::ZeroAttr>(
cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::GlobalViewAttr, cir::ZeroAttr>(
[&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}
Expand All @@ -212,6 +212,7 @@ class CIRAttrToValue {
mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirAttr(cir::ConstRecordAttr attr);
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirAttr(cir::GlobalViewAttr attr);
mlir::Value visitCirAttr(cir::ZeroAttr attr);
Expand Down Expand Up @@ -386,6 +387,21 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
return result;
}

/// ConstRecord visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstRecordAttr constRecord) {
const mlir::Type llvmTy = converter->convertType(constRecord.getType());
const mlir::Location loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

// Iteratively lower each constant element of the record.
for (auto [idx, elt] : llvm::enumerate(constRecord.getMembers())) {
mlir::Value init = visit(elt);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

/// ConstVectorAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
const mlir::Type llvmTy = converter->convertType(attr.getType());
Expand Down Expand Up @@ -1286,6 +1302,11 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
rewriter.eraseOp(op);
return mlir::success();
}
} else if (const auto recordAttr =
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
auto initVal = lowerCirAttrAsValue(op, recordAttr, rewriter, typeConverter);
rewriter.replaceOp(op, initVal);
return mlir::success();
} else if (const auto vecTy = mlir::dyn_cast<cir::VectorType>(op.getType())) {
rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
getTypeConverter()));
Expand Down Expand Up @@ -1527,9 +1548,9 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
assert(
(isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::ConstComplexAttr, cir::GlobalViewAttr, cir::ZeroAttr>(init)));
assert((isa<cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::ConstComplexAttr, cir::GlobalViewAttr,
cir::ZeroAttr>(init)));

// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
Expand Down Expand Up @@ -1582,8 +1603,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
return mlir::failure();
}
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::ConstComplexAttr,
cir::GlobalViewAttr, cir::ZeroAttr>(init.value())) {
cir::ConstRecordAttr, cir::ConstPtrAttr,
cir::ConstComplexAttr, cir::GlobalViewAttr,
cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.
Expand Down
26 changes: 24 additions & 2 deletions clang/test/CIR/CodeGen/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ int dd[3][2] = {{1, 2}, {3, 4}, {5, 6}};
// OGCG: [i32 3, i32 4], [2 x i32] [i32 5, i32 6]]

int e[10] = {1, 2};
// CIR: cir.global external @e = #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>
// CIR: cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct

// LLVM: @e = global [10 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0]
// LLVM: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>

// OGCG: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>

Expand All @@ -58,6 +58,28 @@ int f[5] = {1, 2};

// OGCG: @f = global [5 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0]

int g[16] = {1, 2, 3, 4, 5, 6, 7, 8};
// CIR: cir.global external @g = #cir.const_record<{
// CIR-SAME: #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i,
// CIR-SAME: #cir.int<3> : !s32i, #cir.int<4> : !s32i,
// CIR-SAME: #cir.int<5> : !s32i, #cir.int<6> : !s32i,
// CIR-SAME: #cir.int<7> : !s32i, #cir.int<8> : !s32i]>
// CIR-SAME: : !cir.array<!s32i x 8>,
// CIR-SAME: #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct1

// LLVM: @g = global <{ [8 x i32], [8 x i32] }>
// LLVM-SAME: <{ [8 x i32]
// LLVM-SAME: [i32 1, i32 2, i32 3, i32 4,
// LLVM-SAME: i32 5, i32 6, i32 7, i32 8],
// LLVM-SAME: [8 x i32] zeroinitializer }>

// OGCG: @g = global <{ [8 x i32], [8 x i32] }>
// OGCG-SAME: <{ [8 x i32]
// OGCG-SAME: [i32 1, i32 2, i32 3, i32 4,
// OGCG-SAME: i32 5, i32 6, i32 7, i32 8],
// OGCG-SAME: [8 x i32] zeroinitializer }>


extern int b[10];
// CIR: cir.global "private" external @b : !cir.array<!s32i x 10>
// LLVM: @b = external global [10 x i32]
Expand Down
23 changes: 23 additions & 0 deletions clang/test/CIR/IR/invalid-const-record.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

!s32i = !cir.int<s, 32>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>

// expected-error @below {{expected !cir.record type}}
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !cir.ptr<!rec_anon_struct>

// -----

!s32i = !cir.int<s, 32>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>

// expected-error @below {{number of elements must match}}
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct

// -----

!s32i = !cir.int<s, 32>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>

// expected-error @below {{element at index 1 has type '!cir.float' but the expected type for this element is '!cir.int<s, 32>'}}
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.fp<2.000000e+00> : !cir.float, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
13 changes: 9 additions & 4 deletions clang/test/CIR/IR/struct.cir
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
// CHECK-DAG: !rec_S = !cir.record<struct "S" incomplete>
// CHECK-DAG: !rec_U = !cir.record<union "U" incomplete>

!rec_anon_struct = !cir.record<struct {!cir.array<!cir.ptr<!u8i> x 5>}>
!rec_anon_struct1 = !cir.record<struct {!cir.ptr<!u8i>, !cir.ptr<!u8i>, !cir.ptr<!u8i>}>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
!rec_anon_struct1 = !cir.record<struct {!cir.array<!cir.ptr<!u8i> x 5>}>
!rec_anon_struct2 = !cir.record<struct {!cir.ptr<!u8i>, !cir.ptr<!u8i>, !cir.ptr<!u8i>}>
!rec_S1 = !cir.record<struct "S1" {!s32i, !s32i}>
!rec_Sc = !cir.record<struct "Sc" {!u8i, !u16i, !u32i}>

Expand Down Expand Up @@ -42,18 +43,22 @@
!rec_Node = !cir.record<struct "Node" {!cir.ptr<!cir.record<struct "Node">>}>
// CHECK-DAG: !cir.record<struct "Node" {!cir.ptr<!cir.record<struct "Node">>}>



module {
cir.global external @p1 = #cir.ptr<null> : !cir.ptr<!rec_S>
cir.global external @p2 = #cir.ptr<null> : !cir.ptr<!rec_U>
cir.global external @p3 = #cir.ptr<null> : !cir.ptr<!rec_C>
cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
// CHECK: cir.global external @p1 = #cir.ptr<null> : !cir.ptr<!rec_S>
// CHECK: cir.global external @p2 = #cir.ptr<null> : !cir.ptr<!rec_U>
// CHECK: cir.global external @p3 = #cir.ptr<null> : !cir.ptr<!rec_C>
// CHECK: cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct

// Dummy function to use types and force them to be printed.
cir.func @useTypes(%arg0: !rec_Node,
%arg1: !rec_anon_struct1,
%arg2: !rec_anon_struct,
%arg1: !rec_anon_struct2,
%arg2: !rec_anon_struct1,
%arg3: !rec_S1,
%arg4: !rec_Ac,
%arg5: !rec_P1,
Expand Down
Loading