Skip to content

Commit c19a840

Browse files
committed
tablegen: generate setter methods for operation arguments
Changing an operation's argument without recreating the operation is sometimes convenient and a small compile-time optimization.
1 parent c6778cc commit c19a840

File tree

5 files changed

+206
-32
lines changed

5 files changed

+206
-32
lines changed

example/ExampleMain.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,16 @@ void createFunctionExample(Module &module, const Twine &name) {
8787
b.SetInsertPoint(bb);
8888

8989
Value *x1 = b.create<xd::ReadOp>(b.getInt32Ty());
90-
Value *sizeOf = b.create<xd::SizeOfOp>(b.getDoubleTy());
90+
Value *sizeOf = b.create<xd::SizeOfOp>(b.getHalfTy());
9191
Value *sizeOf32 = b.create<xd::ITruncOp>(b.getInt32Ty(), sizeOf);
92-
Value *x2 = b.create<xd::Add32Op>(x1, sizeOf32, 7);
92+
Value *x2 = b.create<xd::Add32Op>(x1, sizeOf32, 11);
9393
Value *x3 = b.create<xd::CombineOp>(x2, x1);
9494
Value *x4 = b.create<xd::IExtOp>(b.getInt64Ty(), x3);
9595
b.create<xd::WriteOp>(x4);
9696

97+
cast<xd::SizeOfOp>(sizeOf)->setSizeofType(b.getDoubleTy());
98+
cast<xd::Add32Op>(x2)->setExtra(7);
99+
97100
Value *q1 = b.create<xd::ReadOp>(FixedVectorType::get(b.getInt32Ty(), 2));
98101
Value *q2 = b.create<xd::FromFixedVectorOp>(
99102
xd::XdVectorType::get(xd::VectorKind::BigEndian, b.getInt32Ty(), 2), q1);
@@ -104,9 +107,11 @@ void createFunctionExample(Module &module, const Twine &name) {
104107
Value *y3 = b.create<xd::ExtractElementOp>(y1, b.getInt32(2));
105108
Value *y4 = b.CreateAdd(y2, y3);
106109
Value *y5 = b.create<xd::InsertElementOp>(q2, y4, x1);
107-
Value *y6 = b.create<xd::InsertElementOp>(y5, y2, b.getInt32(1));
110+
Value *y6 = b.create<xd::InsertElementOp>(y5, y2, b.getInt32(5));
108111
b.create<xd::WriteOp>(y6);
109112

113+
cast<xd::InsertElementOp>(y6)->setIndex(b.getInt32(1));
114+
110115
Value *p1 = b.create<xd::ReadOp>(b.getPtrTy(0));
111116
Value *p2 = b.create<xd::StreamAddOp>(p1, b.getInt64(14), b.getInt8(0));
112117
b.create<xd::WriteOp>(p2);

include/llvm-dialects/TableGen/Operations.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class OperationBase {
7474
/// List of arguments specific to this class; does not contain superclass
7575
/// arguments, if any.
7676
std::vector<NamedValue> m_arguments;
77+
78+
/// Attribute types as determined for the setter methods.
79+
std::vector<std::string> m_attrTypes;
7780
};
7881

7982
class OpClass : public OperationBase {

lib/TableGen/Operations.cpp

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
using namespace llvm;
2626
using namespace llvm_dialects;
2727

28+
static std::string evaluateAttrLlvmType(raw_ostream &errs, raw_ostream &out,
29+
FmtContext &fmt, Attr *attr,
30+
StringRef name,
31+
GenDialectsContext &context,
32+
SymbolTable &symbols);
33+
2834
static std::optional<std::vector<NamedValue>>
2935
parseArguments(raw_ostream &errs, GenDialectsContext &context, Record *rec) {
3036
Record *superClassRec = rec->getValueAsDef("superclass");
@@ -60,6 +66,33 @@ bool OperationBase::init(raw_ostream &errs, GenDialectsContext &context,
6066
return false;
6167
m_arguments = std::move(*arguments);
6268

69+
for (const auto &arg : m_arguments) {
70+
std::string attrType;
71+
72+
if (auto *attr = dyn_cast<Attr>(arg.type)) {
73+
SymbolTable symbols;
74+
std::string prelude;
75+
raw_string_ostream preludes(prelude);
76+
FmtContext fmt;
77+
fmt.withContext("getContext()");
78+
79+
attrType = evaluateAttrLlvmType(errs, preludes, fmt, attr, arg.name,
80+
context, symbols);
81+
if (attrType.empty())
82+
return false;
83+
84+
if (!prelude.empty()) {
85+
errs << "got a non-empty prelude when determining the LLVM type of "
86+
<< arg.name << '\n';
87+
errs << prelude;
88+
errs << "this is currently unsupported\n";
89+
return false;
90+
}
91+
}
92+
93+
m_attrTypes.push_back(attrType);
94+
}
95+
6396
return true;
6497
}
6598

@@ -80,8 +113,12 @@ unsigned OperationBase::getNumFullArguments() const {
80113
void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
81114
FmtContext &fmt) const {
82115
for (const auto &arg : m_arguments) {
83-
out << tgfmt("$0 get$1();\n", &fmt, arg.type->getBuilderCppType(),
84-
convertToCamelFromSnakeCase(arg.name, true));
116+
out << tgfmt(R"(
117+
$0 get$1();
118+
void set$1($0 $2);
119+
)",
120+
&fmt, arg.type->getBuilderCppType(),
121+
convertToCamelFromSnakeCase(arg.name, true), arg.name);
85122
}
86123
}
87124

@@ -92,19 +129,40 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
92129
numSuperclassArgs = m_superclass->getNumFullArguments();
93130
for (auto indexedArg : llvm::enumerate(m_arguments)) {
94131
const NamedValue &arg = indexedArg.value();
95-
std::string value = llvm::formatv("getArgOperand({0})",
96-
numSuperclassArgs + indexedArg.index());
132+
FmtContextScope scope(fmt);
133+
fmt.withContext("getContext()");
134+
fmt.addSubst("index", Twine(numSuperclassArgs + indexedArg.index()));
135+
fmt.addSubst("cppType", arg.type->getBuilderCppType());
136+
fmt.addSubst("name", arg.name);
137+
fmt.addSubst("Name", convertToCamelFromSnakeCase(arg.name, true));
138+
139+
std::string fromLlvm = tgfmt("getArgOperand($index)", &fmt);
97140
if (auto *attr = dyn_cast<Attr>(arg.type))
98-
value = tgfmt(attr->getFromLlvmValue(), &fmt, value);
141+
fromLlvm = tgfmt(attr->getFromLlvmValue(), &fmt, fromLlvm);
99142
else if (arg.type->isTypeArg())
100-
value += "->getType()";
143+
fromLlvm += "->getType()";
144+
145+
std::string toLlvm = arg.name;
146+
if (auto *attr = dyn_cast<Attr>(arg.type)) {
147+
toLlvm = tgfmt(attr->getToLlvmValue(), &fmt, toLlvm,
148+
m_attrTypes[indexedArg.index()]);
149+
} else if (arg.type->isTypeArg()) {
150+
toLlvm = llvm::formatv("llvm::PoisonValue::get({0})", toLlvm);
151+
}
152+
153+
fmt.addSubst("fromLlvm", fromLlvm);
154+
fmt.addSubst("toLlvm", toLlvm);
155+
101156
out << tgfmt(R"(
102-
$0 $_op::get$1() {
103-
return $2;
157+
$cppType $_op::get$Name() {
158+
return $fromLlvm;
159+
}
160+
161+
void $_op::set$Name($cppType $name) {
162+
setArgOperand($index, $toLlvm);
104163
}
105164
)",
106-
&fmt, arg.type->getBuilderCppType(),
107-
convertToCamelFromSnakeCase(arg.name, true), value);
165+
&fmt);
108166
}
109167
}
110168

test/example/generated/ExampleDialect.cpp.inc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,26 @@ return true;
257257
::llvm::Value * StreamReduceOp::getPtr() {
258258
return getArgOperand(0);
259259
}
260+
261+
void StreamReduceOp::setPtr(::llvm::Value * ptr) {
262+
setArgOperand(0, ptr);
263+
}
260264

261265
::llvm::Value * StreamReduceOp::getCount() {
262266
return getArgOperand(1);
263267
}
268+
269+
void StreamReduceOp::setCount(::llvm::Value * count) {
270+
setArgOperand(1, count);
271+
}
264272

265273
::llvm::Value * StreamReduceOp::getInitial() {
266274
return getArgOperand(2);
267275
}
276+
277+
void StreamReduceOp::setInitial(::llvm::Value * initial) {
278+
setArgOperand(2, initial);
279+
}
268280

269281

270282
const ::llvm::StringLiteral Add32Op::s_name{"xd.add32"};
@@ -364,14 +376,26 @@ uint32_t const extra = getExtra();
364376
::llvm::Value * Add32Op::getLhs() {
365377
return getArgOperand(0);
366378
}
379+
380+
void Add32Op::setLhs(::llvm::Value * lhs) {
381+
setArgOperand(0, lhs);
382+
}
367383

368384
::llvm::Value * Add32Op::getRhs() {
369385
return getArgOperand(1);
370386
}
387+
388+
void Add32Op::setRhs(::llvm::Value * rhs) {
389+
setArgOperand(1, rhs);
390+
}
371391

372392
uint32_t Add32Op::getExtra() {
373393
return ::llvm::cast<::llvm::ConstantInt>(getArgOperand(2))->getZExtValue() ;
374394
}
395+
396+
void Add32Op::setExtra(uint32_t extra) {
397+
setArgOperand(2, ::llvm::ConstantInt::get(::llvm::IntegerType::get(getContext(), 32), extra) );
398+
}
375399

376400
::llvm::Value *Add32Op::getResult() {return this;}
377401

@@ -455,10 +479,18 @@ rhs,
455479
::llvm::Value * CombineOp::getLhs() {
456480
return getArgOperand(0);
457481
}
482+
483+
void CombineOp::setLhs(::llvm::Value * lhs) {
484+
setArgOperand(0, lhs);
485+
}
458486

459487
::llvm::Value * CombineOp::getRhs() {
460488
return getArgOperand(1);
461489
}
490+
491+
void CombineOp::setRhs(::llvm::Value * rhs) {
492+
setArgOperand(1, rhs);
493+
}
462494

463495
::llvm::Value *CombineOp::getResult() {return this;}
464496

@@ -550,10 +582,18 @@ index,
550582
::llvm::Value * ExtractElementOp::getVector() {
551583
return getArgOperand(0);
552584
}
585+
586+
void ExtractElementOp::setVector(::llvm::Value * vector) {
587+
setArgOperand(0, vector);
588+
}
553589

554590
::llvm::Value * ExtractElementOp::getIndex() {
555591
return getArgOperand(1);
556592
}
593+
594+
void ExtractElementOp::setIndex(::llvm::Value * index) {
595+
setArgOperand(1, index);
596+
}
557597

558598
::llvm::Value *ExtractElementOp::getResult() {return this;}
559599

@@ -703,6 +743,10 @@ source,
703743
::llvm::Value * FromFixedVectorOp::getSource() {
704744
return getArgOperand(0);
705745
}
746+
747+
void FromFixedVectorOp::setSource(::llvm::Value * source) {
748+
setArgOperand(0, source);
749+
}
706750

707751
::llvm::Value *FromFixedVectorOp::getResult() {return this;}
708752

@@ -855,6 +899,10 @@ source,
855899
::llvm::Value * IExtOp::getSource() {
856900
return getArgOperand(0);
857901
}
902+
903+
void IExtOp::setSource(::llvm::Value * source) {
904+
setArgOperand(0, source);
905+
}
858906

859907
::llvm::Value *IExtOp::getResult() {return this;}
860908

@@ -947,6 +995,10 @@ source,
947995
::llvm::Value * ITruncOp::getSource() {
948996
return getArgOperand(0);
949997
}
998+
999+
void ITruncOp::setSource(::llvm::Value * source) {
1000+
setArgOperand(0, source);
1001+
}
9501002

9511003
::llvm::Value *ITruncOp::getResult() {return this;}
9521004

@@ -1048,14 +1100,26 @@ index,
10481100
::llvm::Value * InsertElementOp::getVector() {
10491101
return getArgOperand(0);
10501102
}
1103+
1104+
void InsertElementOp::setVector(::llvm::Value * vector) {
1105+
setArgOperand(0, vector);
1106+
}
10511107

10521108
::llvm::Value * InsertElementOp::getValue() {
10531109
return getArgOperand(1);
10541110
}
1111+
1112+
void InsertElementOp::setValue(::llvm::Value * value) {
1113+
setArgOperand(1, value);
1114+
}
10551115

10561116
::llvm::Value * InsertElementOp::getIndex() {
10571117
return getArgOperand(2);
10581118
}
1119+
1120+
void InsertElementOp::setIndex(::llvm::Value * index) {
1121+
setArgOperand(2, index);
1122+
}
10591123

10601124
::llvm::Value *InsertElementOp::getResult() {return this;}
10611125

@@ -1182,6 +1246,10 @@ index,
11821246
::llvm::Type * SizeOfOp::getSizeofType() {
11831247
return getArgOperand(0)->getType();
11841248
}
1249+
1250+
void SizeOfOp::setSizeofType(::llvm::Type * sizeof_type) {
1251+
setArgOperand(0, llvm::PoisonValue::get(sizeof_type));
1252+
}
11851253

11861254
::llvm::Value *SizeOfOp::getResult() {return this;}
11871255

@@ -1510,6 +1578,10 @@ data,
15101578
::llvm::Value * WriteOp::getData() {
15111579
return getArgOperand(0);
15121580
}
1581+
1582+
void WriteOp::setData(::llvm::Value * data) {
1583+
setArgOperand(0, data);
1584+
}
15131585

15141586

15151587

0 commit comments

Comments
 (0)