@@ -54,6 +54,7 @@ struct OpStrings {
54
54
std::string opCppName;
55
55
SmallVector<std::string> opResultNames;
56
56
SmallVector<std::string> opOperandNames;
57
+ SmallVector<std::string> opRegionNames;
57
58
};
58
59
59
60
static std::string joinNameList (llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
87
88
// / Generates OpStrings from an OperatioOp
88
89
static OpStrings getStrings (irdl::OperationOp op) {
89
90
auto operandOp = op.getOp <irdl::OperandsOp>();
90
-
91
91
auto resultOp = op.getOp <irdl::ResultsOp>();
92
+ auto regionsOp = op.getOp <irdl::RegionsOp>();
92
93
93
94
OpStrings strings;
94
95
strings.opName = op.getSymName ();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
108
109
}));
109
110
}
110
111
112
+ if (regionsOp) {
113
+ strings.opRegionNames = SmallVector<std::string>(
114
+ llvm::map_range (regionsOp->getNames (), [](Attribute attr) {
115
+ return llvm::formatv (" {0}" , cast<StringAttr>(attr));
116
+ }));
117
+ }
118
+
111
119
return strings;
112
120
}
113
121
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
122
130
static void fillDict (irdl::detail::dictionary &dict, const OpStrings &strings) {
123
131
const auto operandCount = strings.opOperandNames .size ();
124
132
const auto resultCount = strings.opResultNames .size ();
133
+ const auto regionCount = strings.opRegionNames .size ();
125
134
126
135
dict[" OP_NAME" ] = strings.opName ;
127
136
dict[" OP_CPP_NAME" ] = strings.opCppName ;
@@ -131,6 +140,15 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131
140
operandCount ? joinNameList (strings.opOperandNames ) : " {\"\" }" ;
132
141
dict[" OP_RESULT_INITIALIZER_LIST" ] =
133
142
resultCount ? joinNameList (strings.opResultNames ) : " {\"\" }" ;
143
+ dict[" OP_REGION_COUNT" ] = std::to_string (regionCount);
144
+ dict[" OP_ADD_REGIONS" ] = regionCount
145
+ ? std::string (llvm::formatv (
146
+ R"( for (unsigned i = 0; i != {0}; ++i) {{
147
+ (void)odsState.addRegion();
148
+ }
149
+ )" ,
150
+ regionCount))
151
+ : " " ;
134
152
}
135
153
136
154
// / Fills a dictionary with values from DialectStrings
@@ -179,6 +197,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
179
197
const OpStrings &opStrings) {
180
198
auto opGetters = std::string{};
181
199
auto resGetters = std::string{};
200
+ auto regionGetters = std::string{};
201
+ auto regionAdaptorGetters = std::string{};
182
202
183
203
for (size_t i = 0 , end = opStrings.opOperandNames .size (); i < end; ++i) {
184
204
const auto op =
@@ -196,8 +216,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
196
216
op, i);
197
217
}
198
218
219
+ for (size_t i = 0 , end = opStrings.opRegionNames .size (); i < end; ++i) {
220
+ const auto op =
221
+ llvm::convertToCamelFromSnakeCase (opStrings.opRegionNames [i], true );
222
+ regionAdaptorGetters += llvm::formatv (
223
+ R"( ::mlir::Region &get{0}() { return *getRegions()[{1}]; }
224
+ )" ,
225
+ op, i);
226
+ regionGetters += llvm::formatv (
227
+ R"( ::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
228
+ )" ,
229
+ op, i);
230
+ }
231
+
199
232
dict[" OP_OPERAND_GETTER_DECLS" ] = opGetters;
200
233
dict[" OP_RESULT_GETTER_DECLS" ] = resGetters;
234
+ dict[" OP_REGION_ADAPTER_GETTER_DECLS" ] = regionAdaptorGetters;
235
+ dict[" OP_REGION_GETTER_DECLS" ] = regionGetters;
201
236
}
202
237
203
238
static void generateOpBuilderDeclarations (irdl::detail::dictionary &dict,
@@ -238,6 +273,23 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
238
273
dict[" OP_BUILD_DECLS" ] = buildDecls;
239
274
}
240
275
276
+ // add traits to the dictionary, return true if any were added
277
+ static SmallVector<std::string> generateTraits (irdl::detail::dictionary &dict,
278
+ irdl::OperationOp op,
279
+ const OpStrings &strings) {
280
+ SmallVector<std::string> cppTraitNames;
281
+ if (!strings.opRegionNames .empty ()) {
282
+ cppTraitNames.push_back (
283
+ llvm::formatv (" ::mlir::OpTrait::NRegions<{0}>::Impl" ,
284
+ strings.opRegionNames .size ())
285
+ .str ());
286
+
287
+ // Requires verifyInvariantsImpl is implemented on the op
288
+ cppTraitNames.emplace_back (" ::mlir::OpTrait::OpInvariants" );
289
+ }
290
+ return cppTraitNames;
291
+ }
292
+
241
293
static LogicalResult generateOperationInclude (irdl::OperationOp op,
242
294
raw_ostream &output,
243
295
irdl::detail::dictionary &dict) {
@@ -247,6 +299,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
247
299
const auto opStrings = getStrings (op);
248
300
fillDict (dict, opStrings);
249
301
302
+ SmallVector<std::string> traitNames = generateTraits (dict, op, opStrings);
303
+ if (traitNames.empty ())
304
+ dict[" OP_TEMPLATE_ARGS" ] = opStrings.opCppName ;
305
+ else
306
+ dict[" OP_TEMPLATE_ARGS" ] = llvm::formatv (" {0}, {1}" , opStrings.opCppName ,
307
+ llvm::join (traitNames, " , " ));
308
+
250
309
generateOpGetterDeclarations (dict, opStrings);
251
310
generateOpBuilderDeclarations (dict, opStrings);
252
311
@@ -301,6 +360,111 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
301
360
return success ();
302
361
}
303
362
363
+ static void generateVerifiers (irdl::detail::dictionary &dict,
364
+ irdl::OperationOp op, const OpStrings &strings) {
365
+ SmallVector<std::string> verifierHelpers;
366
+ SmallVector<std::string> verifierCalls;
367
+ auto regionsOp = op.getOp <irdl::RegionsOp>();
368
+ if (strings.opRegionNames .empty () || !regionsOp) {
369
+ // Currently IRDL regions are the only reason to generate a nontrivial
370
+ // verifier, though this will likely change as
371
+ // https://github.com/llvm/llvm-project/issues/158040 is implemented
372
+ std::string verifierDef = llvm::formatv (R"(
373
+ ::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
374
+ return ::mlir::success();
375
+ })" , strings.opCppName );
376
+ dict[" OP_VERIFIER_HELPERS" ] = " " ;
377
+ dict[" OP_VERIFIER" ] = verifierDef;
378
+ return ;
379
+ }
380
+
381
+ for (size_t i = 0 ; i < strings.opRegionNames .size (); ++i) {
382
+ std::string regionName = strings.opRegionNames [i];
383
+ std::string helperFnName =
384
+ llvm::formatv (" __mlir_irdl_local_region_constraint_{0}_{1}" ,
385
+ strings.opCppName , regionName)
386
+ .str ();
387
+
388
+ // Extract the actual region constraint from the IRDL RegionOp
389
+ std::string condition = " true" ;
390
+ std::string textualConditionName = " any region" ;
391
+
392
+ if (auto regionDefOp =
393
+ dyn_cast<irdl::RegionOp>(regionsOp->getArgs ()[i].getDefiningOp ())) {
394
+ // Generate constraint condition based on RegionOp attributes
395
+ SmallVector<std::string> conditionParts;
396
+ SmallVector<std::string> descriptionParts;
397
+
398
+ // Check number of blocks constraint
399
+ if (auto blockCount = regionDefOp.getNumberOfBlocks ()) {
400
+ conditionParts.push_back (
401
+ llvm::formatv (" region.getBlocks().size() == {0}" ,
402
+ blockCount.value ())
403
+ .str ());
404
+ descriptionParts.push_back (
405
+ llvm::formatv (" exactly {0} block(s)" , blockCount.value ()).str ());
406
+ }
407
+
408
+ // Check entry block arguments constraint
409
+ if (regionDefOp.getConstrainedArguments ()) {
410
+ size_t expectedArgCount = regionDefOp.getEntryBlockArgs ().size ();
411
+ conditionParts.push_back (
412
+ llvm::formatv (" region.getNumArguments() == {0}" , expectedArgCount)
413
+ .str ());
414
+ descriptionParts.push_back (
415
+ llvm::formatv (" {0} entry block argument(s)" , expectedArgCount)
416
+ .str ());
417
+ }
418
+
419
+ // Combine conditions
420
+ if (!conditionParts.empty ()) {
421
+ condition = llvm::join (conditionParts, " && " );
422
+ }
423
+
424
+ // Generate descriptive error message
425
+ if (!descriptionParts.empty ()) {
426
+ textualConditionName =
427
+ llvm::formatv (" region with {0}" ,
428
+ llvm::join (descriptionParts, " and " ))
429
+ .str ();
430
+ }
431
+ }
432
+
433
+ verifierHelpers.push_back (llvm::formatv (
434
+ R"( static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{
435
+ if (!({1})) {{
436
+ return op->emitOpError("region #") << regionIndex
437
+ << (regionName.empty() ? " " : " ('" + regionName + "') ")
438
+ << "failed to verify constraint: {2}";
439
+ }
440
+ return ::mlir::success();
441
+ })" ,
442
+ helperFnName, condition, textualConditionName));
443
+
444
+ verifierCalls.push_back (llvm::formatv (R"(
445
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
446
+ return ::mlir::failure();)" ,
447
+ helperFnName, i, regionName)
448
+ .str ());
449
+ }
450
+
451
+ // Add an overall verifier that sequences the helper calls
452
+ std::string verifierDef =
453
+ llvm::formatv (R"(
454
+ ::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
455
+ if(::mlir::failed(verify()))
456
+ return ::mlir::failure();
457
+
458
+ {1}
459
+
460
+ return ::mlir::success();
461
+ })" ,
462
+ strings.opCppName , llvm::join (verifierCalls, " \n " ));
463
+
464
+ dict[" OP_VERIFIER_HELPERS" ] = llvm::join (verifierHelpers, " \n " );
465
+ dict[" OP_VERIFIER" ] = verifierDef;
466
+ }
467
+
304
468
static std::string generateOpDefinition (irdl::detail::dictionary &dict,
305
469
irdl::OperationOp op) {
306
470
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +534,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
370
534
371
535
dict[" OP_BUILD_DEFS" ] = buildDefinition;
372
536
537
+ generateVerifiers (dict, op, opStrings);
538
+
373
539
std::string str;
374
540
llvm::raw_string_ostream stream{str};
375
541
perOpDefTemplate.render (stream, dict);
@@ -427,7 +593,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
427
593
dict[" TYPE_PARSER" ] = llvm::formatv (
428
594
R"( static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
429
595
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
430
- {0}
596
+ {0}
431
597
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
432
598
*mnemonic = keyword;
433
599
return std::nullopt;
@@ -520,6 +686,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
520
686
" IRDL C++ translation does not yet support variadic results" );
521
687
}))
522
688
.Case <irdl::AnyOp>(([](irdl::AnyOp) { return success (); }))
689
+ .Case <irdl::RegionOp>(([](irdl::RegionOp) { return success (); }))
690
+ .Case <irdl::RegionsOp>(([](irdl::RegionsOp) { return success (); }))
523
691
.Default ([](mlir::Operation *op) -> LogicalResult {
524
692
return op->emitError (" IRDL C++ translation does not yet support "
525
693
" translation of " )
0 commit comments