Skip to content

Commit 1167e32

Browse files
committed
tablegen: add AttrEnum
Allows convenient use of C++ enums as attribute operands of operations and dialect types. Note that users were able to just mostly implement this themselves already, purely using .td code, but there were some failures in type casting.
1 parent b716b56 commit 1167e32

File tree

11 files changed

+94
-33
lines changed

11 files changed

+94
-33
lines changed

example/ExampleDialect.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,17 @@
2626
#pragma once
2727

2828
#define GET_INCLUDES
29+
#include "ExampleDialect.h.inc"
30+
31+
namespace xd {
32+
33+
enum class VectorKind {
34+
LittleEndian = 0,
35+
BigEndian = 1,
36+
MiddleEndian = 2,
37+
};
38+
39+
} // namespace xd
40+
2941
#define GET_DIALECT_DECLS
3042
#include "ExampleDialect.h.inc"

example/ExampleDialect.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ def ExampleDialect : Dialect {
3030
let cppNamespace = "xd";
3131
}
3232

33+
defm AttrVectorKind : AttrEnum<"VectorKind">;
34+
3335
def XdVectorType : DialectType<ExampleDialect, "vector"> {
34-
let typeArguments = (args type:$element_type, AttrI32:$num_elements);
36+
let typeArguments = (args AttrVectorKind:$kind, type:$element_type,
37+
AttrI32:$num_elements);
3538

3639
let summary = "a custom vector type";
3740
let description = [{
@@ -114,7 +117,7 @@ def SizeOfOp : ExampleOp<"sizeof", [Memory<[]>, NoUnwind, WillReturn]> {
114117
def ExtractElementOp : ExampleOp<"extractelement", [Memory<[]>, NoUnwind,
115118
WillReturn]> {
116119
let results = (outs value:$result);
117-
let arguments = (ins (XdVectorType $result, any):$vector, I32:$index);
120+
let arguments = (ins (XdVectorType any, $result, any):$vector, I32:$index);
118121

119122
let summary = "extract an element from a vector";
120123
let description = [{
@@ -128,7 +131,7 @@ def InsertElementOp : ExampleOp<"insertelement", [Memory<[]>, NoUnwind,
128131
let arguments = (ins value:$vector, value:$value, I32:$index);
129132

130133
let verifier = [
131-
(XdVectorType $result, $value, any),
134+
(XdVectorType $result, any, $value, any),
132135
(eq $result, $vector),
133136
];
134137

@@ -141,9 +144,11 @@ def InsertElementOp : ExampleOp<"insertelement", [Memory<[]>, NoUnwind,
141144

142145
def FromFixedVectorOp : ExampleOp<"fromfixedvector", [Memory<[]>, NoUnwind,
143146
WillReturn]> {
144-
let results = (outs (XdVectorType $scalar_type, $num_elements):$result);
147+
let results = (outs (XdVectorType $kind, $scalar_type, $num_elements):$result);
145148
let arguments = (ins (FixedVectorType $scalar_type, $num_elements):$source);
146149

150+
let defaultBuilderHasExplicitResultType = true;
151+
147152
let summary = "convert <n x T> to our custom vector type";
148153
let description = [{
149154
Demonstrate a more complex unification case.

example/ExampleMain.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,11 @@ void createFunctionExample(Module &module, const Twine &name) {
9595
b.create<xd::WriteOp>(x4);
9696

9797
Value *q1 = b.create<xd::ReadOp>(FixedVectorType::get(b.getInt32Ty(), 2));
98-
Value *q2 = b.create<xd::FromFixedVectorOp>(q1);
98+
Value *q2 = b.create<xd::FromFixedVectorOp>(
99+
xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 2), q1);
99100

100-
Value *y1 = b.create<xd::ReadOp>(xd::XdVectorType::get(b.getInt32Ty(), 4));
101+
Value *y1 = b.create<xd::ReadOp>(
102+
xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 4));
101103
Value *y2 = b.create<xd::ExtractElementOp>(y1, x1);
102104
Value *y3 = b.create<xd::ExtractElementOp>(y1, b.getInt32(2));
103105
Value *y4 = b.CreateAdd(y2, y3);

include/llvm-dialects/Dialect/Dialect.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ class Attr<string cppType_> : MetaType {
5858

5959
// $0 is the LLVM value.
6060
string fromLlvmValue = ?;
61+
62+
// For use in dialect types. $0 is the C++ value
63+
string toUnsigned = "$0";
64+
65+
// For use in dialect types. $0 is the "type-erased" unsigned value
66+
string fromUnsigned = "$0";
6167
}
6268

6369
class IntegerAttr<string cppType_> : Attr<cppType_> {
@@ -235,6 +241,23 @@ def : AttrLlvmType<AttrI16, I16>;
235241
def : AttrLlvmType<AttrI32, I32>;
236242
def : AttrLlvmType<AttrI64, I64>;
237243

244+
// ============================================================================
245+
/// More general attributes
246+
// ============================================================================
247+
248+
// Define a custom enum attribute that can be used as in operations and dialect
249+
// types
250+
multiclass AttrEnum<string cppType_> {
251+
def NAME : Attr<cppType_> {
252+
let toLlvmValue = "::llvm::ConstantInt::get($1, static_cast<unsigned>($0))";
253+
let fromLlvmValue = "static_cast<" # cppType_ # ">(::llvm::cast<::llvm::ConstantInt>($0)->getZExtValue())";
254+
let toUnsigned = "static_cast<unsigned>($0)";
255+
let fromUnsigned = "static_cast<" # cppType_ # ">($0)";
256+
}
257+
258+
def : AttrLlvmType<!cast<Attr>(NAME), I32>;
259+
}
260+
238261
// ============================================================================
239262
/// Traits
240263
///

include/llvm-dialects/TableGen/Constraints.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ class Attr : public MetaType {
212212
llvm::Init *getLlvmType() const { return m_llvmType; }
213213
llvm::StringRef getToLlvmValue() const { return m_toLlvmValue; }
214214
llvm::StringRef getFromLlvmValue() const { return m_fromLlvmValue; }
215+
llvm::StringRef getToUnsigned() const { return m_toUnsigned; }
216+
llvm::StringRef getFromUnsigned() const { return m_fromUnsigned; }
215217

216218
// Set the LLVMType once -- used during initialization to break a circular
217219
// dependency in how IntegerType is defined.
@@ -227,6 +229,8 @@ class Attr : public MetaType {
227229
llvm::Init *m_llvmType = nullptr;
228230
std::string m_toLlvmValue;
229231
std::string m_fromLlvmValue;
232+
std::string m_toUnsigned;
233+
std::string m_fromUnsigned;
230234
};
231235

232236
} // namespace llvm_dialects

lib/TableGen/Constraints.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ std::unique_ptr<Attr> Attr::parse(raw_ostream &errs,
290290
attr->m_cppType = record->getValueAsString("cppType");
291291
attr->m_toLlvmValue = record->getValueAsString("toLlvmValue");
292292
attr->m_fromLlvmValue = record->getValueAsString("fromLlvmValue");
293+
attr->m_toUnsigned = record->getValueAsString("toUnsigned");
294+
attr->m_fromUnsigned = record->getValueAsString("fromUnsigned");
293295

294296
return attr;
295297
}

lib/TableGen/DialectType.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const {
203203
++typeIdx;
204204
} else {
205205
expr = tgfmt("int_params()[$0]", &fmt, intIdx);
206+
expr = tgfmt(cast<Attr>(argument.type)->getFromUnsigned(), &fmt, expr);
206207
++intIdx;
207208
}
208209

@@ -247,8 +248,11 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const {
247248
&fmt, intIdx);
248249
for (const auto &[argument, getterArg] :
249250
llvm::zip(typeArguments(), getterArgs)) {
250-
if (!argument.type->isTypeArg())
251-
out << getterArg.name << ",\n";
251+
if (!argument.type->isTypeArg()) {
252+
std::string expr = tgfmt(cast<Attr>(argument.type)->getToUnsigned(), &fmt,
253+
getterArg.name);
254+
out << expr << ",\n";
255+
}
252256
}
253257

254258
out << tgfmt(R"(

test/example/generated/ExampleDialect.cpp.inc

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -178,23 +178,29 @@ XdHandleType* XdHandleType::get(::llvm::LLVMContext & ctx) {
178178
}
179179

180180

181+
VectorKind XdVectorType::getKind() const {
182+
return static_cast<VectorKind>(int_params()[0]);
183+
}
184+
185+
181186
::llvm::Type * XdVectorType::getElementType() const {
182187
return type_params()[0];
183188
}
184189

185190

186191
uint32_t XdVectorType::getNumElements() const {
187-
return int_params()[0];
192+
return int_params()[1];
188193
}
189194

190-
XdVectorType* XdVectorType::get(::llvm::Type * elementType, uint32_t numElements) {
195+
XdVectorType* XdVectorType::get(VectorKind kind, ::llvm::Type * elementType, uint32_t numElements) {
191196
::llvm::LLVMContext &ctx = elementType->getContext();
192197
::std::array<::llvm::Type *, 1> types = {
193198
elementType,
194199

195200
};
196-
::std::array<unsigned, 1> ints = {
197-
numElements,
201+
::std::array<unsigned, 2> ints = {
202+
static_cast<unsigned>(kind),
203+
numElements,
198204

199205
};
200206

@@ -218,13 +224,15 @@ elementType,
218224
return false;
219225
}
220226

221-
if (getNumIntParameters() != 1) {
227+
if (getNumIntParameters() != 2) {
222228
errs << " wrong number of int parameters\n";
223-
errs << " expected: 1\n";
229+
errs << " expected: 2\n";
224230
errs << " actual: " << getNumIntParameters() << '\n';
225231
return false;
226232
}
227-
auto elementType = getElementType();
233+
auto kind = getKind();
234+
(void)kind;
235+
auto elementType = getElementType();
228236
(void)elementType;
229237
auto numElements = getNumElements();
230238
(void)numElements;
@@ -519,7 +527,7 @@ index,
519527
}
520528

521529
if (!(::llvm::isa<XdVectorType>(vectorType))) {
522-
errs << " failed check for (XdVectorType ?:$result, any):$vector\n";
530+
errs << " failed check for (XdVectorType any, ?:$result, any):$vector\n";
523531

524532
errs << " with $vector = " << printable(vectorType) << '\n';
525533

@@ -550,7 +558,7 @@ index,
550558

551559
const ::llvm::StringLiteral FromFixedVectorOp::s_name{"xd.fromfixedvector"};
552560

553-
FromFixedVectorOp* FromFixedVectorOp::create(llvm_dialects::Builder& b, ::llvm::Value * source) {
561+
FromFixedVectorOp* FromFixedVectorOp::create(llvm_dialects::Builder& b, ::llvm::Type* resultType, ::llvm::Value * source) {
554562
::llvm::LLVMContext& context = b.getContext();
555563
::llvm::Module& module = *b.GetInsertBlock()->getModule();
556564

@@ -559,8 +567,8 @@ index,
559567
= ExampleDialect::get(context).getAttributeList(1);
560568

561569
std::string mangledName =
562-
::llvm_dialects::getMangledName(s_name, {XdVectorType::get(::llvm::cast<::llvm::FixedVectorType>(source->getType())->getElementType(), ::llvm::cast<::llvm::FixedVectorType>(source->getType())->getNumElements())});
563-
auto fnType = ::llvm::FunctionType::get(XdVectorType::get(::llvm::cast<::llvm::FixedVectorType>(source->getType())->getElementType(), ::llvm::cast<::llvm::FixedVectorType>(source->getType())->getNumElements()), true);
570+
::llvm_dialects::getMangledName(s_name, {resultType});
571+
auto fnType = ::llvm::FunctionType::get(resultType, true);
564572

565573
auto fn = module.getOrInsertFunction(mangledName, fnType, attrs);
566574
::llvm::SmallString<32> newName;
@@ -612,7 +620,7 @@ source,
612620
}
613621

614622
if (!(::llvm::isa<XdVectorType>(resultType))) {
615-
errs << " failed check for (XdVectorType ?:$scalar_type, ?:$num_elements):$result\n";
623+
errs << " failed check for (XdVectorType ?:$kind, ?:$scalar_type, ?:$num_elements):$result\n";
616624

617625
errs << " with $result = " << printable(resultType) << '\n';
618626

@@ -964,7 +972,7 @@ index,
964972
}
965973

966974
if (!(::llvm::isa<XdVectorType>(resultType))) {
967-
errs << " failed check for (XdVectorType ?:$result, ?:$value, any)\n";
975+
errs << " failed check for (XdVectorType ?:$result, any, ?:$value, any)\n";
968976

969977
errs << " with $result = " << printable(resultType) << '\n';
970978

test/example/generated/ExampleDialect.h.inc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ namespace xd {
7878

7979
bool verifier(::llvm::raw_ostream &errs) const;
8080

81-
static XdVectorType *get(::llvm::Type * elementType, uint32_t numElements);
81+
static XdVectorType *get(VectorKind kind, ::llvm::Type * elementType, uint32_t numElements);
8282

83+
VectorKind getKind() const;
8384
::llvm::Type * getElementType() const;
8485
uint32_t getNumElements() const;
8586
};
@@ -179,7 +180,7 @@ bool verifier(::llvm::raw_ostream &errs);
179180
return ::llvm::isa<::llvm::CallInst>(v) &&
180181
classof(::llvm::cast<::llvm::CallInst>(v));
181182
}
182-
static FromFixedVectorOp* create(::llvm_dialects::Builder& b, ::llvm::Value * source);
183+
static FromFixedVectorOp* create(::llvm_dialects::Builder& b, ::llvm::Type* resultType, ::llvm::Value * source);
183184

184185
bool verifier(::llvm::raw_ostream &errs);
185186

test/example/test-builder.test

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
; CHECK-NEXT: [[TMP5:%.*]] = call i64 (...) @xd.iext.i64(i32 [[TMP4]])
1414
; CHECK-NEXT: call void (...) @xd.write(i64 [[TMP5]])
1515
; CHECK-NEXT: [[TMP6:%.*]] = call <2 x i32> @xd.read.v2i32()
16-
; CHECK-NEXT: [[TMP7:%.*]] = call target("xd.vector", i32, 2) (...) @xd.fromfixedvector.txd.vector_i32_2t(<2 x i32> [[TMP6]])
17-
; CHECK-NEXT: [[TMP8:%.*]] = call target("xd.vector", i32, 4) @xd.read.txd.vector_i32_4t()
18-
; CHECK-NEXT: [[TMP9:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 4) [[TMP8]], i32 [[TMP0]])
19-
; CHECK-NEXT: [[TMP10:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 4) [[TMP8]], i32 2)
16+
; CHECK-NEXT: [[TMP7:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.fromfixedvector.txd.vector_i32_1_2t(<2 x i32> [[TMP6]])
17+
; CHECK-NEXT: [[TMP8:%.*]] = call target("xd.vector", i32, 1, 4) @xd.read.txd.vector_i32_1_4t()
18+
; CHECK-NEXT: [[TMP9:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP8]], i32 [[TMP0]])
19+
; CHECK-NEXT: [[TMP10:%.*]] = call i32 (...) @xd.extractelement.i32(target("xd.vector", i32, 1, 4) [[TMP8]], i32 2)
2020
; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP9]], [[TMP10]]
21-
; CHECK-NEXT: [[TMP12:%.*]] = call target("xd.vector", i32, 2) (...) @xd.insertelement.txd.vector_i32_2t(target("xd.vector", i32, 2) [[TMP7]], i32 [[TMP11]], i32 [[TMP0]])
22-
; CHECK-NEXT: [[TMP13:%.*]] = call target("xd.vector", i32, 2) (...) @xd.insertelement.txd.vector_i32_2t(target("xd.vector", i32, 2) [[TMP12]], i32 [[TMP9]], i32 1)
23-
; CHECK-NEXT: call void (...) @xd.write(target("xd.vector", i32, 2) [[TMP13]])
21+
; CHECK-NEXT: [[TMP12:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP7]], i32 [[TMP11]], i32 [[TMP0]])
22+
; CHECK-NEXT: [[TMP13:%.*]] = call target("xd.vector", i32, 1, 2) (...) @xd.insertelement.txd.vector_i32_1_2t(target("xd.vector", i32, 1, 2) [[TMP12]], i32 [[TMP9]], i32 1)
23+
; CHECK-NEXT: call void (...) @xd.write(target("xd.vector", i32, 1, 2) [[TMP13]])
2424
; CHECK-NEXT: [[TMP14:%.*]] = call ptr @xd.read.p0()
2525
; CHECK-NEXT: [[TMP15:%.*]] = call i8 (...) @xd.stream.add.i8(ptr [[TMP14]], i64 14, i8 0)
2626
; CHECK-NEXT: call void (...) @xd.write(i8 [[TMP15]])

0 commit comments

Comments
 (0)