Skip to content

Commit 99d8590

Browse files
authored
[mlir] [irdl] Add support for regions in irdl-to-cpp (#158540)
Fixes #158034 For the input ```mlir irdl.dialect @conditional_dialect { // A conditional operation with regions irdl.operation @conditional { // Create region constraints %r0 = irdl.region // Unconstrained region %r1 = irdl.region() // Region with no entry block arguments %v0 = irdl.any %r2 = irdl.region(%v0) // Region with one i1 entry block argument irdl.regions(cond: %r2, then: %r0, else: %r1) } } ``` This produces the following cpp: https://gist.github.com/j2kun/d2095f108efbd8d403576d5c460e0c00 Summary of changes: - The op class and adaptor get named accessors to the regions `Region &get<RegionName>()` and `getRegions()` - The op now gets `OpTrait::NRegions<3>` and `OpInvariants` to trigger the region verification - Support for region block argument constraints is added, but not working for all constraints until codegen for `irdl.is` is added (filed #161018 and left a TODO). - Helper functions for the individual verification steps are added, following mlir-tblgen's format (in the above gist, `__mlir_irdl_local_region_constraint_ConditionalOp_cond` and similar), and `verifyInvariantsImpl` that calls them. - Regions are added in the builder ## Questions for the reviewer ### What is the "correct" interface for verification? I used `mlir-tblgen` on an analogous version of the example `ConditionalOp` in this PR, and I see an `::mlir::OpTrait::OpInvariants` trait as well as ```cpp ::llvm::LogicalResult ConditionalOp::verifyInvariantsImpl() { { unsigned index = 0; (void)index; for (auto &region : ::llvm::MutableArrayRef((*this)->getRegion(0))) if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "cond", index++))) return ::mlir::failure(); for (auto &region : ::llvm::MutableArrayRef((*this)->getRegion(1))) if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "then", index++))) return ::mlir::failure(); for (auto &region : ::llvm::MutableArrayRef((*this)->getRegion(2))) if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "else", index++))) return ::mlir::failure(); } return ::mlir::success(); } ::llvm::LogicalResult ConditionalOp::verifyInvariants() { if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) return ::mlir::success(); return ::mlir::failure(); } ``` However, `OpInvariants` only seems to need `verifyInvariantsImpl`, so it's not clear to me what is the purpose of the `verifyInvariants` function, or, if I leave out `verifyInvariants`, whether I need to call `verify()` in my implementation of `verifyInvariantsImpl`. In this PR, I omitted `verifyInvariants` and generated `verifyInvariantsImpl`. ### Is testing sufficient? I am not certain I implemented the builders properly, and it's unclear to me to what extent the existing tests check this (which look like they compile the generated cpp, but don't actually use it). Did I omit some standard function or overload? --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent 67c000e commit 99d8590

File tree

8 files changed

+299
-59
lines changed

8 files changed

+299
-59
lines changed

mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct OpStrings {
5454
std::string opCppName;
5555
SmallVector<std::string> opResultNames;
5656
SmallVector<std::string> opOperandNames;
57+
SmallVector<std::string> opRegionNames;
5758
};
5859

5960
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
8788
/// Generates OpStrings from an OperatioOp
8889
static OpStrings getStrings(irdl::OperationOp op) {
8990
auto operandOp = op.getOp<irdl::OperandsOp>();
90-
9191
auto resultOp = op.getOp<irdl::ResultsOp>();
92+
auto regionsOp = op.getOp<irdl::RegionsOp>();
9293

9394
OpStrings strings;
9495
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
108109
}));
109110
}
110111

112+
if (regionsOp) {
113+
strings.opRegionNames = SmallVector<std::string>(
114+
llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
115+
return llvm::formatv("{0}", cast<StringAttr>(attr));
116+
}));
117+
}
118+
111119
return strings;
112120
}
113121

@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
122130
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
123131
const auto operandCount = strings.opOperandNames.size();
124132
const auto resultCount = strings.opResultNames.size();
133+
const auto regionCount = strings.opRegionNames.size();
125134

126135
dict["OP_NAME"] = strings.opName;
127136
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131140
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
132141
dict["OP_RESULT_INITIALIZER_LIST"] =
133142
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
143+
dict["OP_REGION_COUNT"] = std::to_string(regionCount);
134144
}
135145

136146
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
179189
const OpStrings &opStrings) {
180190
auto opGetters = std::string{};
181191
auto resGetters = std::string{};
192+
auto regionGetters = std::string{};
193+
auto regionAdaptorGetters = std::string{};
182194

183195
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
184196
const auto op =
@@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
196208
op, i);
197209
}
198210

211+
for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
212+
const auto op =
213+
llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
214+
regionAdaptorGetters += llvm::formatv(
215+
R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
216+
)",
217+
op, i);
218+
regionGetters += llvm::formatv(
219+
R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
220+
)",
221+
op, i);
222+
}
223+
199224
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
200225
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
226+
dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
227+
dict["OP_REGION_GETTER_DECLS"] = regionGetters;
201228
}
202229

203230
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
238265
dict["OP_BUILD_DECLS"] = buildDecls;
239266
}
240267

268+
// add traits to the dictionary, return true if any were added
269+
static SmallVector<std::string> generateTraits(irdl::OperationOp op,
270+
const OpStrings &strings) {
271+
SmallVector<std::string> cppTraitNames;
272+
if (!strings.opRegionNames.empty()) {
273+
cppTraitNames.push_back(
274+
llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
275+
strings.opRegionNames.size())
276+
.str());
277+
278+
// Requires verifyInvariantsImpl is implemented on the op
279+
cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
280+
}
281+
return cppTraitNames;
282+
}
283+
241284
static LogicalResult generateOperationInclude(irdl::OperationOp op,
242285
raw_ostream &output,
243286
irdl::detail::dictionary &dict) {
@@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
247290
const auto opStrings = getStrings(op);
248291
fillDict(dict, opStrings);
249292

293+
SmallVector<std::string> traitNames = generateTraits(op, opStrings);
294+
if (traitNames.empty())
295+
dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
296+
else
297+
dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
298+
llvm::join(traitNames, ", "));
299+
250300
generateOpGetterDeclarations(dict, opStrings);
251301
generateOpBuilderDeclarations(dict, opStrings);
252302

@@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
301351
return success();
302352
}
303353

354+
static void generateRegionConstraintVerifiers(
355+
irdl::detail::dictionary &dict, irdl::OperationOp op,
356+
const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
357+
SmallVectorImpl<std::string> &verifierCalls) {
358+
auto regionsOp = op.getOp<irdl::RegionsOp>();
359+
if (strings.opRegionNames.empty() || !regionsOp)
360+
return;
361+
362+
for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
363+
std::string regionName = strings.opRegionNames[i];
364+
std::string helperFnName =
365+
llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
366+
strings.opCppName, regionName)
367+
.str();
368+
369+
// Extract the actual region constraint from the IRDL RegionOp
370+
std::string condition = "true";
371+
std::string textualConditionName = "any region";
372+
373+
if (auto regionDefOp =
374+
dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
375+
// Generate constraint condition based on RegionOp attributes
376+
SmallVector<std::string> conditionParts;
377+
SmallVector<std::string> descriptionParts;
378+
379+
// Check number of blocks constraint
380+
if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
381+
conditionParts.push_back(
382+
llvm::formatv("region.getBlocks().size() == {0}",
383+
blockCount.value())
384+
.str());
385+
descriptionParts.push_back(
386+
llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
387+
}
388+
389+
// Check entry block arguments constraint
390+
if (regionDefOp.getConstrainedArguments()) {
391+
size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
392+
conditionParts.push_back(
393+
llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
394+
.str());
395+
descriptionParts.push_back(
396+
llvm::formatv("{0} entry block argument(s)", expectedArgCount)
397+
.str());
398+
}
399+
400+
// Combine conditions
401+
if (!conditionParts.empty()) {
402+
condition = llvm::join(conditionParts, " && ");
403+
}
404+
405+
// Generate descriptive error message
406+
if (!descriptionParts.empty()) {
407+
textualConditionName =
408+
llvm::formatv("region with {0}",
409+
llvm::join(descriptionParts, " and "))
410+
.str();
411+
}
412+
}
413+
414+
verifierHelpers.push_back(llvm::formatv(
415+
R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
416+
if (!({1})) {{
417+
return op->emitOpError("region #") << regionIndex
418+
<< (regionName.empty() ? " " : " ('" + regionName + "') ")
419+
<< "failed to verify constraint: {2}";
420+
}
421+
return ::mlir::success();
422+
})",
423+
helperFnName, condition, textualConditionName));
424+
425+
verifierCalls.push_back(llvm::formatv(R"(
426+
if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
427+
return ::mlir::failure();)",
428+
helperFnName, i, regionName)
429+
.str());
430+
}
431+
}
432+
433+
static void generateVerifiers(irdl::detail::dictionary &dict,
434+
irdl::OperationOp op, const OpStrings &strings) {
435+
SmallVector<std::string> verifierHelpers;
436+
SmallVector<std::string> verifierCalls;
437+
438+
generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
439+
verifierCalls);
440+
441+
// Add an overall verifier that sequences the helper calls
442+
std::string verifierDef =
443+
llvm::formatv(R"(
444+
::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
445+
if(::mlir::failed(verify()))
446+
return ::mlir::failure();
447+
448+
{1}
449+
450+
return ::mlir::success();
451+
})",
452+
strings.opCppName, llvm::join(verifierCalls, "\n"));
453+
454+
dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
455+
dict["OP_VERIFIER"] = verifierDef;
456+
}
457+
304458
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
305459
irdl::OperationOp op) {
306460
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
370524

371525
dict["OP_BUILD_DEFS"] = buildDefinition;
372526

527+
generateVerifiers(dict, op, opStrings);
528+
373529
std::string str;
374530
llvm::raw_string_ostream stream{str};
375531
perOpDefTemplate.render(stream, dict);
@@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
427583
dict["TYPE_PARSER"] = llvm::formatv(
428584
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
429585
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
430-
{0}
586+
{0}
431587
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
432588
*mnemonic = keyword;
433589
return std::nullopt;
@@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
520676
"IRDL C++ translation does not yet support variadic results");
521677
}))
522678
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
679+
.Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
680+
.Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
523681
.Default([](mlir::Operation *op) -> LogicalResult {
524682
return op->emitError("IRDL C++ translation does not yet support "
525683
"translation of ")

mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ public:
1212
struct Properties {
1313
};
1414
public:
15-
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
16-
: odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
17-
odsRegions(op->getRegions())
15+
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
16+
: odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
17+
odsRegions(op->getRegions())
1818
{}
1919

2020
/// Return the unstructured operand index of a structured operand along with
2121
// the amount of unstructured operands it contains.
2222
std::pair<unsigned, unsigned>
23-
getStructuredOperandIndexAndLength (unsigned index,
23+
getStructuredOperandIndexAndLength (unsigned index,
2424
unsigned odsOperandsSize) {
2525
return {index, 1};
2626
}
@@ -32,6 +32,12 @@ public:
3232
::mlir::DictionaryAttr getAttributes() {
3333
return odsAttrs;
3434
}
35+
36+
__OP_REGION_ADAPTER_GETTER_DECLS__
37+
38+
::mlir::RegionRange getRegions() {
39+
return odsRegions;
40+
}
3541
protected:
3642
::mlir::DictionaryAttr odsAttrs;
3743
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
4248
} // namespace detail
4349

4450
template <typename RangeT>
45-
class __OP_CPP_NAME__GenericAdaptor
51+
class __OP_CPP_NAME__GenericAdaptor
4652
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
4753
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
4854
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
4955
public:
5056
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
51-
::mlir::OpaqueProperties properties,
52-
::mlir::RegionRange regions = {})
53-
: __OP_CPP_NAME__GenericAdaptor(values, attrs,
54-
(properties ? *properties.as<::mlir::EmptyProperties *>()
57+
::mlir::OpaqueProperties properties,
58+
::mlir::RegionRange regions = {})
59+
: __OP_CPP_NAME__GenericAdaptor(values, attrs,
60+
(properties ? *properties.as<::mlir::EmptyProperties *>()
5561
: ::mlir::EmptyProperties{}), regions) {}
5662

57-
__OP_CPP_NAME__GenericAdaptor(RangeT values,
63+
__OP_CPP_NAME__GenericAdaptor(RangeT values,
5864
const __OP_CPP_NAME__GenericAdaptorBase &base)
5965
: Base(base), odsOperands(values) {}
6066

61-
// This template parameter allows using __OP_CPP_NAME__ which is declared
67+
// This template parameter allows using __OP_CPP_NAME__ which is declared
6268
// later.
6369
template <typename LateInst = __OP_CPP_NAME__,
6470
typename = std::enable_if_t<
6571
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
66-
__OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
72+
__OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
6773
: Base(op), odsOperands(values) {}
6874

6975
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
7783
RangeT getStructuredOperands(unsigned index) {
7884
auto valueRange = getStructuredOperandIndexAndLength(index);
7985
return {std::next(odsOperands.begin(), valueRange.first),
80-
std::next(odsOperands.begin(),
86+
std::next(odsOperands.begin(),
8187
valueRange.first + valueRange.second)};
8288
}
8389

@@ -91,7 +97,7 @@ private:
9197
RangeT odsOperands;
9298
};
9399

94-
class __OP_CPP_NAME__Adaptor
100+
class __OP_CPP_NAME__Adaptor
95101
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
96102
public:
97103
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
100106
::llvm::LogicalResult verify(::mlir::Location loc);
101107
};
102108

103-
class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
109+
class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
104110
public:
105111
using Op::Op;
106112
using Op::print;
@@ -112,6 +118,8 @@ public:
112118
return {};
113119
}
114120

121+
::llvm::LogicalResult verifyInvariantsImpl();
122+
115123
static constexpr ::llvm::StringLiteral getOperationName() {
116124
return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__");
117125
}
@@ -147,7 +155,7 @@ public:
147155
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
148156
auto valueRange = getStructuredOperandIndexAndLength(index);
149157
return {std::next(getOperation()->operand_begin(), valueRange.first),
150-
std::next(getOperation()->operand_begin(),
158+
std::next(getOperation()->operand_begin(),
151159
valueRange.first + valueRange.second)};
152160
}
153161

@@ -162,18 +170,19 @@ public:
162170
::mlir::Operation::result_range getStructuredResults(unsigned index) {
163171
auto valueRange = getStructuredResultIndexAndLength(index);
164172
return {std::next(getOperation()->result_begin(), valueRange.first),
165-
std::next(getOperation()->result_begin(),
173+
std::next(getOperation()->result_begin(),
166174
valueRange.first + valueRange.second)};
167175
}
168176

169177
__OP_OPERAND_GETTER_DECLS__
170178
__OP_RESULT_GETTER_DECLS__
171-
179+
__OP_REGION_GETTER_DECLS__
180+
172181
__OP_BUILD_DECLS__
173-
static void build(::mlir::OpBuilder &odsBuilder,
174-
::mlir::OperationState &odsState,
175-
::mlir::TypeRange resultTypes,
176-
::mlir::ValueRange operands,
182+
static void build(::mlir::OpBuilder &odsBuilder,
183+
::mlir::OperationState &odsState,
184+
::mlir::TypeRange resultTypes,
185+
::mlir::ValueRange operands,
177186
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
178187

179188
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,

0 commit comments

Comments
 (0)