Skip to content

Commit f6f1a2e

Browse files
authored
Support enum class in struct-backed type (#133)
We may use enumeration class in struct backed type, which makes struct type more semantic. However, we need to force these values to be cast from integers to the target type.
1 parent 8620e69 commit f6f1a2e

File tree

5 files changed

+29
-19
lines changed

5 files changed

+29
-19
lines changed

example/ExampleDialect.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def StructBackedType : DialectType<ExampleDialect, "struct.backed"> {
367367
let description = [{
368368
Test that a struct-backed type works correctly.
369369
}];
370-
let typeArguments = (args AttrI32:$field0, AttrI32:$field1, AttrI32:$field2);
370+
let typeArguments = (args AttrI32:$field0, AttrI8:$field1, AttrVectorKind:$field2);
371371
let representation = (repr_struct (IntegerType 41));
372372

373373
let defaultGetterHasExplicitContextArgument = 1;

example/ExampleMain.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ void createFunctionExample(Module &module, const Twine &name) {
149149

150150
b.create<xd::cpp::StringAttrOp>("Hello world!");
151151

152-
xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, 2);
152+
xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, xd::cpp::VectorKind::BigEndian);
153153
auto *structBackedVal = b.create<xd::cpp::DummyStructBackedOutpOp>(structBackedTy, b.getInt32(42), "gen.struct.backed.val");
154154
b.create<xd::cpp::DummyStructBackedInpOp>(structBackedVal, "consume.struct.backed.val");
155155

lib/TableGen/DialectType.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,17 +201,24 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const {
201201
out << " static bool classof(const ::llvm::Type *t);\n\n";
202202

203203
unsigned fieldIdx = 1; // sentinel
204+
auto getCastExpr = [&fmt](const NamedValue &argument,
205+
llvm::StringRef expr) -> std::string {
206+
return tgfmt(cast<Attr>(argument.type)->getFromUnsigned(), &fmt, expr);
207+
};
204208
for (const auto &argument : typeArguments()) {
205209
std::string camel = convertToCamelFromSnakeCase(argument.name, true);
206210
out << tgfmt(
207-
R"( unsigned get$0() const {
208-
::llvm::Type *elt = getElementType($1);
211+
R"( $0 get$1() const {
212+
::llvm::Type *elt = getElementType($2);
209213
if (elt->isStructTy())
210-
return 0;
211-
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
214+
return $3;
215+
return $4;
212216
}
213217
)",
214-
&fmt, camel, fieldIdx++);
218+
&fmt, argument.type->getCppType(), camel, fieldIdx++,
219+
getCastExpr(argument, "0"),
220+
getCastExpr(argument,
221+
"::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth()"));
215222
}
216223

217224
out << " };\n\n";
@@ -307,14 +314,17 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const {
307314
" $fields.push_back(::llvm::IntegerType::get($_context, $0));\n", &fmt,
308315
Twine(m_structSentinelBitWidth));
309316

310-
for (const auto &getterArg : getterArgs) {
317+
for (const auto &[argument, getterArg] :
318+
llvm::zip(typeArguments(), getterArgs)) {
319+
std::string castExpr = tgfmt(cast<Attr>(argument.type)->getToUnsigned(),
320+
&fmt, getterArg.name);
311321
out << tgfmt(R"(
312322
if ($0 == 0)
313323
$fields.push_back(::llvm::StructType::get($_context));
314324
else
315325
$fields.push_back(::llvm::IntegerType::get($_context, $0));
316326
)",
317-
&fmt, getterArg.name);
327+
&fmt, castExpr);
318328
}
319329
out << tgfmt(" auto *$st = ::llvm::StructType::create($_context, "
320330
"$fields, $os.str(), /*isPacked=*/false);\n",

test/example/generated/ExampleDialect.cpp.inc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,10 @@ m_attributeLists[6] = argAttrList.addFnAttributes(context, attrBuilder);
258258
}
259259
}
260260

261-
StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2) {
262-
261+
StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2) {
263262

264263

264+
static_assert(sizeof(field2) <= sizeof(unsigned));
265265
std::string name; ::llvm::raw_string_ostream os(name);
266266
os << "struct.backed";
267267
os << '.' << (uint64_t)field0;
@@ -280,10 +280,10 @@ StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t fiel
280280
else
281281
fields.push_back(::llvm::IntegerType::get(ctx, field1));
282282

283-
if (field2 == 0)
283+
if (static_cast<unsigned>(field2) == 0)
284284
fields.push_back(::llvm::StructType::get(ctx));
285285
else
286-
fields.push_back(::llvm::IntegerType::get(ctx, field2));
286+
fields.push_back(::llvm::IntegerType::get(ctx, static_cast<unsigned>(field2)));
287287
auto *st = ::llvm::StructType::create(ctx, fields, os.str(), /*isPacked=*/false);
288288
return static_cast<StructBackedType *>(st);
289289
}

test/example/generated/ExampleDialect.h.inc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,27 @@ namespace xd::cpp {
5858
using ::llvm::StructType::getElementType;
5959

6060
static StructBackedType *get(
61-
::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2);
61+
::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2);
6262

6363
static bool classof(const ::llvm::Type *t);
6464

65-
unsigned getField0() const {
65+
uint32_t getField0() const {
6666
::llvm::Type *elt = getElementType(1);
6767
if (elt->isStructTy())
6868
return 0;
6969
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
7070
}
71-
unsigned getField1() const {
71+
uint8_t getField1() const {
7272
::llvm::Type *elt = getElementType(2);
7373
if (elt->isStructTy())
7474
return 0;
7575
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
7676
}
77-
unsigned getField2() const {
77+
VectorKind getField2() const {
7878
::llvm::Type *elt = getElementType(3);
7979
if (elt->isStructTy())
80-
return 0;
81-
return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth();
80+
return static_cast<VectorKind>(0);
81+
return static_cast<VectorKind>(::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth());
8282
}
8383
};
8484

0 commit comments

Comments
 (0)