Skip to content

Commit 303294d

Browse files
committed
tablegen: generate setter methods for operation arguments
Changing an operation's argument without recreating the operation is sometimes convenient and a small compile-time optimization.
1 parent 4258a6b commit 303294d

File tree

5 files changed

+206
-32
lines changed

5 files changed

+206
-32
lines changed

example/ExampleMain.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,16 @@ void createFunctionExample(Module &module, const Twine &name) {
8787
b.SetInsertPoint(bb);
8888

8989
Value *x1 = b.create<xd::ReadOp>(b.getInt32Ty());
90-
Value *sizeOf = b.create<xd::SizeOfOp>(b.getDoubleTy());
90+
Value *sizeOf = b.create<xd::SizeOfOp>(b.getHalfTy());
9191
Value *sizeOf32 = b.create<xd::ITruncOp>(b.getInt32Ty(), sizeOf);
92-
Value *x2 = b.create<xd::Add32Op>(x1, sizeOf32, 7);
92+
Value *x2 = b.create<xd::Add32Op>(x1, sizeOf32, 11);
9393
Value *x3 = b.create<xd::CombineOp>(x2, x1);
9494
Value *x4 = b.create<xd::IExtOp>(b.getInt64Ty(), x3);
9595
b.create<xd::WriteOp>(x4);
9696

97+
cast<xd::SizeOfOp>(sizeOf)->setSizeofType(b.getDoubleTy());
98+
cast<xd::Add32Op>(x2)->setExtra(7);
99+
97100
Value *q1 = b.create<xd::ReadOp>(FixedVectorType::get(b.getInt32Ty(), 2));
98101
Value *q2 = b.create<xd::FromFixedVectorOp>(
99102
xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 2), q1);
@@ -104,9 +107,11 @@ void createFunctionExample(Module &module, const Twine &name) {
104107
Value *y3 = b.create<xd::ExtractElementOp>(y1, b.getInt32(2));
105108
Value *y4 = b.CreateAdd(y2, y3);
106109
Value *y5 = b.create<xd::InsertElementOp>(q2, y4, x1);
107-
Value *y6 = b.create<xd::InsertElementOp>(y5, y2, b.getInt32(1));
110+
Value *y6 = b.create<xd::InsertElementOp>(y5, y2, b.getInt32(5));
108111
b.create<xd::WriteOp>(y6);
109112

113+
cast<xd::InsertElementOp>(y6)->setIndex(b.getInt32(1));
114+
110115
Value *p1 = b.create<xd::ReadOp>(b.getPtrTy(0));
111116
Value *p2 = b.create<xd::StreamAddOp>(p1, b.getInt64(14), b.getInt8(0));
112117
b.create<xd::WriteOp>(p2);

include/llvm-dialects/TableGen/Operations.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class OperationBase {
7474
/// List of arguments specific to this class; does not contain superclass
7575
/// arguments, if any.
7676
std::vector<NamedValue> m_arguments;
77+
78+
/// Attribute types as determined for the setter methods.
79+
std::vector<std::string> m_attrTypes;
7780
};
7881

7982
class OpClass : public OperationBase {

lib/TableGen/Operations.cpp

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
using namespace llvm;
2626
using namespace llvm_dialects;
2727

28+
static std::string evaluateAttrLlvmType(raw_ostream &errs, raw_ostream &out,
29+
FmtContext &fmt, Attr *attr,
30+
StringRef name,
31+
GenDialectsContext &context,
32+
SymbolTable &symbols);
33+
2834
static std::optional<std::vector<NamedValue>>
2935
parseArguments(raw_ostream &errs, GenDialectsContext &context, Record *rec) {
3036
Record *superClassRec = rec->getValueAsDef("superclass");
@@ -60,6 +66,33 @@ bool OperationBase::init(raw_ostream &errs, GenDialectsContext &context,
6066
return false;
6167
m_arguments = std::move(*arguments);
6268

69+
for (const auto &arg : m_arguments) {
70+
std::string attrType;
71+
72+
if (auto *attr = dyn_cast<Attr>(arg.type)) {
73+
SymbolTable symbols;
74+
std::string prelude;
75+
raw_string_ostream preludes(prelude);
76+
FmtContext fmt;
77+
fmt.withContext("getContext()");
78+
79+
attrType = evaluateAttrLlvmType(errs, preludes, fmt, attr, arg.name,
80+
context, symbols);
81+
if (attrType.empty())
82+
return false;
83+
84+
if (!prelude.empty()) {
85+
errs << "got a non-empty prelude when determining the LLVM type of "
86+
<< arg.name << '\n';
87+
errs << prelude;
88+
errs << "this is currently unsupported\n";
89+
return false;
90+
}
91+
}
92+
93+
m_attrTypes.push_back(attrType);
94+
}
95+
6396
return true;
6497
}
6598

@@ -80,8 +113,12 @@ unsigned OperationBase::getNumFullArguments() const {
80113
void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
81114
FmtContext &fmt) const {
82115
for (const auto &arg : m_arguments) {
83-
out << tgfmt("$0 get$1();\n", &fmt, arg.type->getBuilderCppType(),
84-
convertToCamelFromSnakeCase(arg.name, true));
116+
out << tgfmt(R"(
117+
$0 get$1();
118+
void set$1($0 $2);
119+
)",
120+
&fmt, arg.type->getBuilderCppType(),
121+
convertToCamelFromSnakeCase(arg.name, true), arg.name);
85122
}
86123
}
87124

@@ -92,19 +129,40 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
92129
numSuperclassArgs = m_superclass->getNumFullArguments();
93130
for (auto indexedArg : llvm::enumerate(m_arguments)) {
94131
const NamedValue &arg = indexedArg.value();
95-
std::string value = llvm::formatv("getArgOperand({0})",
96-
numSuperclassArgs + indexedArg.index());
132+
FmtContextScope scope(fmt);
133+
fmt.withContext("getContext()");
134+
fmt.addSubst("index", Twine(numSuperclassArgs + indexedArg.index()));
135+
fmt.addSubst("cppType", arg.type->getBuilderCppType());
136+
fmt.addSubst("name", arg.name);
137+
fmt.addSubst("Name", convertToCamelFromSnakeCase(arg.name, true));
138+
139+
std::string fromLlvm = tgfmt("getArgOperand($index)", &fmt);
97140
if (auto *attr = dyn_cast<Attr>(arg.type))
98-
value = tgfmt(attr->getFromLlvmValue(), &fmt, value);
141+
fromLlvm = tgfmt(attr->getFromLlvmValue(), &fmt, fromLlvm);
99142
else if (arg.type->isTypeArg())
100-
value += "->getType()";
143+
fromLlvm += "->getType()";
144+
145+
std::string toLlvm = arg.name;
146+
if (auto *attr = dyn_cast<Attr>(arg.type)) {
147+
toLlvm = tgfmt(attr->getToLlvmValue(), &fmt, toLlvm,
148+
m_attrTypes[indexedArg.index()]);
149+
} else if (arg.type->isTypeArg()) {
150+
toLlvm = llvm::formatv("llvm::PoisonValue::get({0})", toLlvm);
151+
}
152+
153+
fmt.addSubst("fromLlvm", fromLlvm);
154+
fmt.addSubst("toLlvm", toLlvm);
155+
101156
out << tgfmt(R"(
102-
$0 $_op::get$1() {
103-
return $2;
157+
$cppType $_op::get$Name() {
158+
return $fromLlvm;
159+
}
160+
161+
void $_op::set$Name($cppType $name) {
162+
setArgOperand($index, $toLlvm);
104163
}
105164
)",
106-
&fmt, arg.type->getBuilderCppType(),
107-
convertToCamelFromSnakeCase(arg.name, true), value);
165+
&fmt);
108166
}
109167
}
110168

test/example/generated/ExampleDialect.cpp.inc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,26 @@ return true;
255255
::llvm::Value * StreamReduceOp::getPtr() {
256256
return getArgOperand(0);
257257
}
258+
259+
void StreamReduceOp::setPtr(::llvm::Value * ptr) {
260+
setArgOperand(0, ptr);
261+
}
258262

259263
::llvm::Value * StreamReduceOp::getCount() {
260264
return getArgOperand(1);
261265
}
266+
267+
void StreamReduceOp::setCount(::llvm::Value * count) {
268+
setArgOperand(1, count);
269+
}
262270

263271
::llvm::Value * StreamReduceOp::getInitial() {
264272
return getArgOperand(2);
265273
}
274+
275+
void StreamReduceOp::setInitial(::llvm::Value * initial) {
276+
setArgOperand(2, initial);
277+
}
266278

267279

268280
const ::llvm::StringLiteral Add32Op::s_name{"xd.add32"};
@@ -361,14 +373,26 @@ uint32_t const extra = getExtra();
361373
::llvm::Value * Add32Op::getLhs() {
362374
return getArgOperand(0);
363375
}
376+
377+
void Add32Op::setLhs(::llvm::Value * lhs) {
378+
setArgOperand(0, lhs);
379+
}
364380

365381
::llvm::Value * Add32Op::getRhs() {
366382
return getArgOperand(1);
367383
}
384+
385+
void Add32Op::setRhs(::llvm::Value * rhs) {
386+
setArgOperand(1, rhs);
387+
}
368388

369389
uint32_t Add32Op::getExtra() {
370390
return ::llvm::cast<::llvm::ConstantInt>(getArgOperand(2))->getZExtValue() ;
371391
}
392+
393+
void Add32Op::setExtra(uint32_t extra) {
394+
setArgOperand(2, ::llvm::ConstantInt::get(::llvm::IntegerType::get(getContext(), 32), extra) );
395+
}
372396

373397
::llvm::Value *Add32Op::getResult() {return this;}
374398

@@ -452,10 +476,18 @@ rhs,
452476
::llvm::Value * CombineOp::getLhs() {
453477
return getArgOperand(0);
454478
}
479+
480+
void CombineOp::setLhs(::llvm::Value * lhs) {
481+
setArgOperand(0, lhs);
482+
}
455483

456484
::llvm::Value * CombineOp::getRhs() {
457485
return getArgOperand(1);
458486
}
487+
488+
void CombineOp::setRhs(::llvm::Value * rhs) {
489+
setArgOperand(1, rhs);
490+
}
459491

460492
::llvm::Value *CombineOp::getResult() {return this;}
461493

@@ -547,10 +579,18 @@ index,
547579
::llvm::Value * ExtractElementOp::getVector() {
548580
return getArgOperand(0);
549581
}
582+
583+
void ExtractElementOp::setVector(::llvm::Value * vector) {
584+
setArgOperand(0, vector);
585+
}
550586

551587
::llvm::Value * ExtractElementOp::getIndex() {
552588
return getArgOperand(1);
553589
}
590+
591+
void ExtractElementOp::setIndex(::llvm::Value * index) {
592+
setArgOperand(1, index);
593+
}
554594

555595
::llvm::Value *ExtractElementOp::getResult() {return this;}
556596

@@ -700,6 +740,10 @@ source,
700740
::llvm::Value * FromFixedVectorOp::getSource() {
701741
return getArgOperand(0);
702742
}
743+
744+
void FromFixedVectorOp::setSource(::llvm::Value * source) {
745+
setArgOperand(0, source);
746+
}
703747

704748
::llvm::Value *FromFixedVectorOp::getResult() {return this;}
705749

@@ -852,6 +896,10 @@ source,
852896
::llvm::Value * IExtOp::getSource() {
853897
return getArgOperand(0);
854898
}
899+
900+
void IExtOp::setSource(::llvm::Value * source) {
901+
setArgOperand(0, source);
902+
}
855903

856904
::llvm::Value *IExtOp::getResult() {return this;}
857905

@@ -944,6 +992,10 @@ source,
944992
::llvm::Value * ITruncOp::getSource() {
945993
return getArgOperand(0);
946994
}
995+
996+
void ITruncOp::setSource(::llvm::Value * source) {
997+
setArgOperand(0, source);
998+
}
947999

9481000
::llvm::Value *ITruncOp::getResult() {return this;}
9491001

@@ -1045,14 +1097,26 @@ index,
10451097
::llvm::Value * InsertElementOp::getVector() {
10461098
return getArgOperand(0);
10471099
}
1100+
1101+
void InsertElementOp::setVector(::llvm::Value * vector) {
1102+
setArgOperand(0, vector);
1103+
}
10481104

10491105
::llvm::Value * InsertElementOp::getValue() {
10501106
return getArgOperand(1);
10511107
}
1108+
1109+
void InsertElementOp::setValue(::llvm::Value * value) {
1110+
setArgOperand(1, value);
1111+
}
10521112

10531113
::llvm::Value * InsertElementOp::getIndex() {
10541114
return getArgOperand(2);
10551115
}
1116+
1117+
void InsertElementOp::setIndex(::llvm::Value * index) {
1118+
setArgOperand(2, index);
1119+
}
10561120

10571121
::llvm::Value *InsertElementOp::getResult() {return this;}
10581122

@@ -1179,6 +1243,10 @@ index,
11791243
::llvm::Type * SizeOfOp::getSizeofType() {
11801244
return getArgOperand(0)->getType();
11811245
}
1246+
1247+
void SizeOfOp::setSizeofType(::llvm::Type * sizeof_type) {
1248+
setArgOperand(0, llvm::PoisonValue::get(sizeof_type));
1249+
}
11821250

11831251
::llvm::Value *SizeOfOp::getResult() {return this;}
11841252

@@ -1507,6 +1575,10 @@ data,
15071575
::llvm::Value * WriteOp::getData() {
15081576
return getArgOperand(0);
15091577
}
1578+
1579+
void WriteOp::setData(::llvm::Value * data) {
1580+
setArgOperand(0, data);
1581+
}
15101582

15111583

15121584

0 commit comments

Comments
 (0)