Skip to content

Commit 86ff4a7

Browse files
committed
Add support for regions in irdl-to-cpp
1 parent 17f6888 commit 86ff4a7

File tree

5 files changed

+259
-53
lines changed

5 files changed

+259
-53
lines changed

mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp

Lines changed: 170 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct OpStrings {
5454
std::string opCppName;
5555
SmallVector<std::string> opResultNames;
5656
SmallVector<std::string> opOperandNames;
57+
SmallVector<std::string> opRegionNames;
5758
};
5859

5960
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
8788
/// Generates OpStrings from an OperatioOp
8889
static OpStrings getStrings(irdl::OperationOp op) {
8990
auto operandOp = op.getOp<irdl::OperandsOp>();
90-
9191
auto resultOp = op.getOp<irdl::ResultsOp>();
92+
auto regionsOp = op.getOp<irdl::RegionsOp>();
9293

9394
OpStrings strings;
9495
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
108109
}));
109110
}
110111

112+
if (regionsOp) {
113+
strings.opRegionNames = SmallVector<std::string>(
114+
llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
115+
return llvm::formatv("{0}", cast<StringAttr>(attr));
116+
}));
117+
}
118+
111119
return strings;
112120
}
113121

@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
122130
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
123131
const auto operandCount = strings.opOperandNames.size();
124132
const auto resultCount = strings.opResultNames.size();
133+
const auto regionCount = strings.opRegionNames.size();
125134

126135
dict["OP_NAME"] = strings.opName;
127136
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,15 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131140
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
132141
dict["OP_RESULT_INITIALIZER_LIST"] =
133142
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
143+
dict["OP_REGION_COUNT"] = std::to_string(regionCount);
144+
dict["OP_ADD_REGIONS"] = regionCount
145+
? std::string(llvm::formatv(
146+
R"(for (unsigned i = 0; i != {0}; ++i) {{
147+
(void)odsState.addRegion();
148+
}
149+
)",
150+
regionCount))
151+
: "";
134152
}
135153

136154
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +197,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
179197
const OpStrings &opStrings) {
180198
auto opGetters = std::string{};
181199
auto resGetters = std::string{};
200+
auto regionGetters = std::string{};
201+
auto regionAdaptorGetters = std::string{};
182202

183203
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
184204
const auto op =
@@ -196,8 +216,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
196216
op, i);
197217
}
198218

219+
for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
220+
const auto op =
221+
llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
222+
regionAdaptorGetters += llvm::formatv(
223+
R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
224+
)",
225+
op, i);
226+
regionGetters += llvm::formatv(
227+
R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
228+
)",
229+
op, i);
230+
}
231+
199232
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
200233
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
234+
dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
235+
dict["OP_REGION_GETTER_DECLS"] = regionGetters;
201236
}
202237

203238
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +273,23 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
238273
dict["OP_BUILD_DECLS"] = buildDecls;
239274
}
240275

276+
// add traits to the dictionary, return true if any were added
277+
static SmallVector<std::string> generateTraits(irdl::detail::dictionary &dict,
278+
irdl::OperationOp op,
279+
const OpStrings &strings) {
280+
SmallVector<std::string> cppTraitNames;
281+
if (!strings.opRegionNames.empty()) {
282+
cppTraitNames.push_back(
283+
llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
284+
strings.opRegionNames.size())
285+
.str());
286+
287+
// Requires verifyInvariantsImpl is implemented on the op
288+
cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
289+
}
290+
return cppTraitNames;
291+
}
292+
241293
static LogicalResult generateOperationInclude(irdl::OperationOp op,
242294
raw_ostream &output,
243295
irdl::detail::dictionary &dict) {
@@ -247,6 +299,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
247299
const auto opStrings = getStrings(op);
248300
fillDict(dict, opStrings);
249301

302+
SmallVector<std::string> traitNames = generateTraits(dict, op, opStrings);
303+
if (traitNames.empty())
304+
dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
305+
else
306+
dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
307+
llvm::join(traitNames, ", "));
308+
250309
generateOpGetterDeclarations(dict, opStrings);
251310
generateOpBuilderDeclarations(dict, opStrings);
252311

@@ -301,6 +360,111 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
301360
return success();
302361
}
303362

363+
static void generateVerifiers(irdl::detail::dictionary &dict,
364+
irdl::OperationOp op, const OpStrings &strings) {
365+
SmallVector<std::string> verifierHelpers;
366+
SmallVector<std::string> verifierCalls;
367+
auto regionsOp = op.getOp<irdl::RegionsOp>();
368+
if (strings.opRegionNames.empty() || !regionsOp) {
369+
// Currently IRDL regions are the only reason to generate a nontrivial
370+
// verifier, though this will likely change as
371+
// https://github.com/llvm/llvm-project/issues/158040 is implemented
372+
std::string verifierDef = llvm::formatv(R"(
373+
::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
374+
return ::mlir::success();
375+
})", strings.opCppName);
376+
dict["OP_VERIFIER_HELPERS"] = "";
377+
dict["OP_VERIFIER"] = verifierDef;
378+
return;
379+
}
380+
381+
for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
382+
std::string regionName = strings.opRegionNames[i];
383+
std::string helperFnName =
384+
llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
385+
strings.opCppName, regionName)
386+
.str();
387+
388+
// Extract the actual region constraint from the IRDL RegionOp
389+
std::string condition = "true";
390+
std::string textualConditionName = "any region";
391+
392+
if (auto regionDefOp =
393+
dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
394+
// Generate constraint condition based on RegionOp attributes
395+
SmallVector<std::string> conditionParts;
396+
SmallVector<std::string> descriptionParts;
397+
398+
// Check number of blocks constraint
399+
if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
400+
conditionParts.push_back(
401+
llvm::formatv("region.getBlocks().size() == {0}",
402+
blockCount.value())
403+
.str());
404+
descriptionParts.push_back(
405+
llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
406+
}
407+
408+
// Check entry block arguments constraint
409+
if (regionDefOp.getConstrainedArguments()) {
410+
size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
411+
conditionParts.push_back(
412+
llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
413+
.str());
414+
descriptionParts.push_back(
415+
llvm::formatv("{0} entry block argument(s)", expectedArgCount)
416+
.str());
417+
}
418+
419+
// Combine conditions
420+
if (!conditionParts.empty()) {
421+
condition = llvm::join(conditionParts, " && ");
422+
}
423+
424+
// Generate descriptive error message
425+
if (!descriptionParts.empty()) {
426+
textualConditionName =
427+
llvm::formatv("region with {0}",
428+
llvm::join(descriptionParts, " and "))
429+
.str();
430+
}
431+
}
432+
433+
verifierHelpers.push_back(llvm::formatv(
434+
R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
435+
if (!({1})) {{
436+
return op->emitOpError("region #") << regionIndex
437+
<< (regionName.empty() ? " " : " ('" + regionName + "') ")
438+
<< "failed to verify constraint: {2}";
439+
}
440+
return ::mlir::success();
441+
})",
442+
helperFnName, condition, textualConditionName));
443+
444+
verifierCalls.push_back(llvm::formatv(R"(
445+
if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
446+
return ::mlir::failure();)",
447+
helperFnName, i, regionName)
448+
.str());
449+
}
450+
451+
// Add an overall verifier that sequences the helper calls
452+
std::string verifierDef =
453+
llvm::formatv(R"(
454+
::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
455+
if(::mlir::failed(verify()))
456+
return ::mlir::failure();
457+
458+
{1}
459+
460+
return ::mlir::success();
461+
})",
462+
strings.opCppName, llvm::join(verifierCalls, "\n"));
463+
464+
dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
465+
dict["OP_VERIFIER"] = verifierDef;
466+
}
467+
304468
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
305469
irdl::OperationOp op) {
306470
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +534,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
370534

371535
dict["OP_BUILD_DEFS"] = buildDefinition;
372536

537+
generateVerifiers(dict, op, opStrings);
538+
373539
std::string str;
374540
llvm::raw_string_ostream stream{str};
375541
perOpDefTemplate.render(stream, dict);
@@ -427,7 +593,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
427593
dict["TYPE_PARSER"] = llvm::formatv(
428594
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
429595
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
430-
{0}
596+
{0}
431597
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
432598
*mnemonic = keyword;
433599
return std::nullopt;
@@ -520,6 +686,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
520686
"IRDL C++ translation does not yet support variadic results");
521687
}))
522688
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
689+
.Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
690+
.Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
523691
.Default([](mlir::Operation *op) -> LogicalResult {
524692
return op->emitError("IRDL C++ translation does not yet support "
525693
"translation of ")

0 commit comments

Comments
 (0)