@@ -1083,15 +1083,15 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
10831083}
10841084
10851085// ===----------------------------------------------------------------------===//
1086- // Type Constraints
1086+ // Constraints
10871087// ===----------------------------------------------------------------------===//
10881088
10891089// / Find all type constraints for which a C++ function should be generated.
1090- static std::vector<Constraint>
1091- getAllTypeConstraints ( const RecordKeeper &records ) {
1090+ static std::vector<Constraint> getAllCppConstraints ( const RecordKeeper &records,
1091+ StringRef constraintKind ) {
10921092 std::vector<Constraint> result;
10931093 for (const Record *def :
1094- records.getAllDerivedDefinitionsIfDefined (" TypeConstraint " )) {
1094+ records.getAllDerivedDefinitionsIfDefined (constraintKind )) {
10951095 // Ignore constraints defined outside of the top-level file.
10961096 if (llvm::SrcMgr.FindBufferContainingLoc (def->getLoc ()[0 ]) !=
10971097 llvm::SrcMgr.getMainFileID ())
@@ -1105,32 +1105,74 @@ getAllTypeConstraints(const RecordKeeper &records) {
11051105 return result;
11061106}
11071107
1108+ static std::vector<Constraint>
1109+ getAllCppTypeConstraints (const RecordKeeper &records) {
1110+ return getAllCppConstraints (records, " TypeConstraint" );
1111+ }
1112+
1113+ static std::vector<Constraint>
1114+ getAllCppAttrConstraints (const RecordKeeper &records) {
1115+ return getAllCppConstraints (records, " AttrConstraint" );
1116+ }
1117+
1118+ // / Emit the declarations for the given constraints, of the form:
1119+ // / `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>);`
1120+ static void emitConstraintDecls (const std::vector<Constraint> &constraints,
1121+ raw_ostream &os, StringRef parameterTypeName,
1122+ StringRef parameterName) {
1123+ static const char *const constraintDecl = " bool {0}({1} {2});\n " ;
1124+ for (Constraint constr : constraints)
1125+ os << strfmt (constraintDecl, *constr.getCppFunctionName (),
1126+ parameterTypeName, parameterName);
1127+ }
1128+
11081129static void emitTypeConstraintDecls (const RecordKeeper &records,
11091130 raw_ostream &os) {
1110- static const char * const typeConstraintDecl = R"(
1111- bool {0}(::mlir::Type type);
1112- )" ;
1131+ emitConstraintDecls ( getAllCppTypeConstraints (records), os, " ::mlir::Type " ,
1132+ " type" );
1133+ }
11131134
1114- for (Constraint constr : getAllTypeConstraints (records))
1115- os << strfmt (typeConstraintDecl, *constr.getCppFunctionName ());
1135+ static void emitAttrConstraintDecls (const RecordKeeper &records,
1136+ raw_ostream &os) {
1137+ emitConstraintDecls (getAllCppAttrConstraints (records), os,
1138+ " ::mlir::Attribute" , " attr" );
11161139}
11171140
1118- static void emitTypeConstraintDefs (const RecordKeeper &records,
1119- raw_ostream &os) {
1120- static const char *const typeConstraintDef = R"(
1121- bool {0}(::mlir::Type type) {
1122- return ({1});
1141+ // / Emit the definitions for the given constraints, of the form:
1142+ // / `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>) {
1143+ // / return (<condition>); }`
1144+ // / where `<condition>` is the condition template with the `self` variable
1145+ // / replaced with the `selfName` parameter.
1146+ static void emitConstraintDefs (const std::vector<Constraint> &constraints,
1147+ raw_ostream &os, StringRef parameterTypeName,
1148+ StringRef selfName) {
1149+ static const char *const constraintDef = R"(
1150+ bool {0}({1} {2}) {
1151+ return ({3});
11231152}
11241153)" ;
11251154
1126- for (Constraint constr : getAllTypeConstraints (records) ) {
1155+ for (Constraint constr : constraints ) {
11271156 FmtContext ctx;
1128- ctx.withSelf (" type " );
1157+ ctx.withSelf (selfName );
11291158 std::string condition = tgfmt (constr.getConditionTemplate (), &ctx);
1130- os << strfmt (typeConstraintDef, *constr.getCppFunctionName (), condition);
1159+ os << strfmt (constraintDef, *constr.getCppFunctionName (), parameterTypeName,
1160+ selfName, condition);
11311161 }
11321162}
11331163
1164+ static void emitTypeConstraintDefs (const RecordKeeper &records,
1165+ raw_ostream &os) {
1166+ emitConstraintDefs (getAllCppTypeConstraints (records), os, " ::mlir::Type" ,
1167+ " type" );
1168+ }
1169+
1170+ static void emitAttrConstraintDefs (const RecordKeeper &records,
1171+ raw_ostream &os) {
1172+ emitConstraintDefs (getAllCppAttrConstraints (records), os, " ::mlir::Attribute" ,
1173+ " attr" );
1174+ }
1175+
11341176// ===----------------------------------------------------------------------===//
11351177// GEN: Registration hooks
11361178// ===----------------------------------------------------------------------===//
@@ -1158,6 +1200,21 @@ static mlir::GenRegistration
11581200 return generator.emitDecls (attrDialect);
11591201 });
11601202
1203+ static mlir::GenRegistration
1204+ genAttrConstrDefs (" gen-attr-constraint-defs" ,
1205+ " Generate attribute constraint definitions" ,
1206+ [](const RecordKeeper &records, raw_ostream &os) {
1207+ emitAttrConstraintDefs (records, os);
1208+ return false ;
1209+ });
1210+ static mlir::GenRegistration
1211+ genAttrConstrDecls (" gen-attr-constraint-decls" ,
1212+ " Generate attribute constraint declarations" ,
1213+ [](const RecordKeeper &records, raw_ostream &os) {
1214+ emitAttrConstraintDecls (records, os);
1215+ return false ;
1216+ });
1217+
11611218// ===----------------------------------------------------------------------===//
11621219// TypeDef
11631220// ===----------------------------------------------------------------------===//
0 commit comments