Skip to content

Commit b5bc285

Browse files
committed
Support enum class in struct-backed type
We may use enumeration class in struct backed type, which makes struct type more semantic. However, we need to force these values to be cast from integers to the target type.
1 parent 8620e69 commit b5bc285

File tree

5 files changed

+29
-19
lines changed

5 files changed

+29
-19
lines changed

example/ExampleDialect.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def StructBackedType : DialectType<ExampleDialect, "struct.backed"> {
367367
let description = [{
368368
Test that a struct-backed type works correctly.
369369
}];
370-
let typeArguments = (args AttrI32:$field0, AttrI32:$field1, AttrI32:$field2);
370+
let typeArguments = (args AttrI32:$field0, AttrI8:$field1, AttrVectorKind:$field2);
371371
let representation = (repr_struct (IntegerType 41));
372372

373373
let defaultGetterHasExplicitContextArgument = 1;

example/ExampleMain.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ void createFunctionExample(Module &module, const Twine &name) {
149149

150150
b.create<xd::cpp::StringAttrOp>("Hello world!");
151151

152-
xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, 2);
152+
xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, xd::cpp::VectorKind::BigEndian);
153153
auto *structBackedVal = b.create<xd::cpp::DummyStructBackedOutpOp>(structBackedTy, b.getInt32(42), "gen.struct.backed.val");
154154
b.create<xd::cpp::DummyStructBackedInpOp>(structBackedVal, "consume.struct.backed.val");
155155

lib/TableGen/DialectType.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,17 +201,24 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const {
201201
out << " static bool classof(const ::llvm::Type *t);\n\n";
202202

203203
unsigned fieldIdx = 1; // sentinel
204+
auto getCastExpr = [&fmt](const NamedValue &argument,
205+
llvm::StringRef expr) -> std::string {
206+
return tgfmt(cast<Attr>(argument.type)->getFromUnsigned(), &fmt, expr);
207+
};
204208
for (const auto &argument : typeArguments()) {
205209
std::string camel = convertToCamelFromSnakeCase(argument.name, true);
206210
out << tgfmt(
207-
R"( unsigned get$0() const {
208-
::llvm::Type *elt = getElementType($1);
211+
R"( $0 get$1() const {
212+
::llvm::Type *elt = getElementType($2);
209213
if (elt->isStructTy())
210-
return 0;
211-
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
214+
return $3;
215+
return $4;
212216
}
213217
)",
214-
&fmt, camel, fieldIdx++);
218+
&fmt, argument.type->getCppType(), camel, fieldIdx++,
219+
getCastExpr(argument, "0"),
220+
getCastExpr(argument,
221+
"::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth()"));
215222
}
216223

217224
out << " };\n\n";
@@ -307,14 +314,17 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const {
307314
" $fields.push_back(::llvm::IntegerType::get($_context, $0));\n", &fmt,
308315
Twine(m_structSentinelBitWidth));
309316

310-
for (const auto &getterArg : getterArgs) {
317+
for (const auto &[argument, getterArg] :
318+
llvm::zip(typeArguments(), getterArgs)) {
319+
std::string castExpr = tgfmt(cast<Attr>(argument.type)->getToUnsigned(),
320+
&fmt, getterArg.name);
311321
out << tgfmt(R"(
312322
if ($0 == 0)
313323
$fields.push_back(::llvm::StructType::get($_context));
314324
else
315325
$fields.push_back(::llvm::IntegerType::get($_context, $0));
316326
)",
317-
&fmt, getterArg.name);
327+
&fmt, castExpr);
318328
}
319329
out << tgfmt(" auto *$st = ::llvm::StructType::create($_context, "
320330
"$fields, $os.str(), /*isPacked=*/false);\n",

test/example/generated/ExampleDialect.cpp.inc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,10 @@ m_attributeLists[6] = argAttrList.addFnAttributes(context, attrBuilder);
258258
}
259259
}
260260

261-
StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2) {
262-
261+
StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2) {
263262

264263

264+
static_assert(sizeof(field2) <= sizeof(unsigned));
265265
std::string name; ::llvm::raw_string_ostream os(name);
266266
os << "struct.backed";
267267
os << '.' << (uint64_t)field0;
@@ -280,10 +280,10 @@ StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t fiel
280280
else
281281
fields.push_back(::llvm::IntegerType::get(ctx, field1));
282282

283-
if (field2 == 0)
283+
if (static_cast<unsigned>(field2) == 0)
284284
fields.push_back(::llvm::StructType::get(ctx));
285285
else
286-
fields.push_back(::llvm::IntegerType::get(ctx, field2));
286+
fields.push_back(::llvm::IntegerType::get(ctx, static_cast<unsigned>(field2)));
287287
auto *st = ::llvm::StructType::create(ctx, fields, os.str(), /*isPacked=*/false);
288288
return static_cast<StructBackedType *>(st);
289289
}

test/example/generated/ExampleDialect.h.inc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,27 @@ namespace xd::cpp {
5858
using ::llvm::StructType::getElementType;
5959

6060
static StructBackedType *get(
61-
::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2);
61+
::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2);
6262

6363
static bool classof(const ::llvm::Type *t);
6464

65-
unsigned getField0() const {
65+
uint32_t getField0() const {
6666
::llvm::Type *elt = getElementType(1);
6767
if (elt->isStructTy())
6868
return 0;
6969
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
7070
}
71-
unsigned getField1() const {
71+
uint8_t getField1() const {
7272
::llvm::Type *elt = getElementType(2);
7373
if (elt->isStructTy())
7474
return 0;
7575
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
7676
}
77-
unsigned getField2() const {
77+
VectorKind getField2() const {
7878
::llvm::Type *elt = getElementType(3);
7979
if (elt->isStructTy())
80-
return 0;
81-
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
80+
return static_cast<VectorKind>(0);
81+
return static_cast<VectorKind>(::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth());
8282
}
8383
};
8484

0 commit comments

Comments
 (0)