-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR] feat(mlir-tblgen): Add support for dialect interfaces #170046
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
504bb5c
4ca9d25
e3d4a43
3f0fc8b
aa94279
0531140
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| // RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL | ||
|
|
||
| include "mlir/IR/Interfaces.td" | ||
|
|
||
| def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> { | ||
| let description = [{ | ||
| This is an example dialect interface without default method body. | ||
| }]; | ||
|
|
||
| let cppNamespace = "::mlir::example"; | ||
|
|
||
| let methods = [ | ||
| InterfaceMethod< | ||
| /*desc=*/ "Check if it's an example dialect", | ||
| /*returnType=*/ "bool", | ||
| /*methodName=*/ "isExampleDialect", | ||
| /*args=*/ (ins) | ||
| >, | ||
| InterfaceMethod< | ||
| /*desc=*/ "second method to check if multiple methods supported", | ||
| /*returnType=*/ "unsigned", | ||
| /*methodName=*/ "supportSecondMethod", | ||
| /*args=*/ (ins "::mlir::Type":$type) | ||
| > | ||
|
|
||
| ]; | ||
| } | ||
|
|
||
| // DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod> | ||
| // DECL: virtual bool isExampleDialect() const = 0; | ||
| // DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const = 0; | ||
| // DECL: protected: | ||
| // DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {} | ||
|
|
||
| def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> { | ||
| let description = [{ | ||
| This is an example dialect interface with default method bodies. | ||
| }]; | ||
|
|
||
| let cppNamespace = "::mlir::example"; | ||
|
|
||
| let methods = [ | ||
| InterfaceMethod< | ||
| /*desc=*/ "Check if it's an example dialect", | ||
| /*returnType=*/ "bool", | ||
| /*methodName=*/ "isExampleDialect", | ||
| /*args=*/ (ins), | ||
| /*methodBody=*/ [{ | ||
| return true; | ||
| }] | ||
| >, | ||
| InterfaceMethod< | ||
| /*desc=*/ "second method to check if multiple methods supported", | ||
| /*returnType=*/ "unsigned", | ||
| /*methodName=*/ "supportSecondMethod", | ||
| /*args=*/ (ins "::mlir::Type":$type) | ||
| > | ||
|
|
||
| ]; | ||
| } | ||
|
|
||
| // DECL: virtual bool isExampleDialect() const; | ||
| // DECL: bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const { | ||
| // DECL-NEXT: return true; | ||
| // DECL-NEXT: } | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| //===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // DialectInterfaceGen generates definitions for Dialect interfaces. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "CppGenUtilities.h" | ||
| #include "DocGenUtilities.h" | ||
| #include "mlir/TableGen/GenInfo.h" | ||
| #include "mlir/TableGen/Interfaces.h" | ||
| #include "llvm/ADT/StringExtras.h" | ||
| #include "llvm/Support/FormatVariadic.h" | ||
| #include "llvm/Support/raw_ostream.h" | ||
| #include "llvm/TableGen/CodeGenHelpers.h" | ||
| #include "llvm/TableGen/Error.h" | ||
| #include "llvm/TableGen/Record.h" | ||
| #include "llvm/TableGen/TableGenBackend.h" | ||
|
|
||
| using namespace mlir; | ||
| using llvm::Record; | ||
| using llvm::RecordKeeper; | ||
| using mlir::tblgen::Interface; | ||
| using mlir::tblgen::InterfaceMethod; | ||
|
|
||
| /// Emit a string corresponding to a C++ type, followed by a space if necessary. | ||
| static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { | ||
| type = type.trim(); | ||
| os << type; | ||
| if (type.back() != '&' && type.back() != '*') | ||
| os << " "; | ||
| return os; | ||
| } | ||
|
|
||
| /// Emit the method name and argument list for the given method. | ||
| static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, | ||
| raw_ostream &os) { | ||
| os << name << '('; | ||
| llvm::interleaveComma(method.getArguments(), os, | ||
| [&](const InterfaceMethod::Argument &arg) { | ||
| os << arg.type << " " << arg.name; | ||
| }); | ||
| os << ") const"; | ||
| } | ||
|
|
||
| /// Get an array of all Dialect Interface definitions | ||
| static std::vector<const Record *> | ||
| getAllInterfaceDefinitions(const RecordKeeper &records) { | ||
| std::vector<const Record *> defs = | ||
| records.getAllDerivedDefinitions("DialectInterface"); | ||
|
|
||
| llvm::erase_if(defs, [&](const Record *def) { | ||
| // Ignore interfaces defined outside of the top-level file. | ||
| return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != | ||
| llvm::SrcMgr.getMainFileID(); | ||
| }); | ||
| return defs; | ||
| } | ||
|
|
||
| namespace { | ||
| /// This struct is the generator used when processing tablegen dialect | ||
| /// interfaces. | ||
| class DialectInterfaceGenerator { | ||
| public: | ||
| DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) | ||
| : defs(getAllInterfaceDefinitions(records)), os(os) {} | ||
|
|
||
| bool emitInterfaceDecls(); | ||
|
|
||
| protected: | ||
| void emitInterfaceDecl(const Interface &interface); | ||
| void emitInterfaceMethodsDef(const Interface &interface); | ||
|
|
||
| /// The set of interface records to emit. | ||
| std::vector<const Record *> defs; | ||
| // The stream to emit to. | ||
| raw_ostream &os; | ||
| }; | ||
| } // namespace | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // GEN: Interface declarations | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| static void emitInterfaceMethodDoc(const InterfaceMethod &method, | ||
| raw_ostream &os, StringRef prefix = "") { | ||
| if (std::optional<StringRef> description = method.getDescription()) | ||
| tblgen::emitDescriptionComment(*description, os, prefix); | ||
| } | ||
|
|
||
| static void emitInterfaceDeclMethods(const Interface &interface, | ||
| raw_ostream &os) { | ||
| for (auto &method : interface.getMethods()) { | ||
| emitInterfaceMethodDoc(method, os, " "); | ||
| os << " virtual "; | ||
| emitCPPType(method.getReturnType(), os); | ||
| emitMethodNameAndArgs(method, method.getName(), os); | ||
| if (!method.getBody()) | ||
| // no default method body | ||
| os << " = 0"; | ||
| os << ";\n"; | ||
| } | ||
| } | ||
|
|
||
| void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) { | ||
| llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); | ||
|
|
||
| StringRef interfaceName = interface.getName(); | ||
|
|
||
| tblgen::emitSummaryAndDescComments(os, "", | ||
| interface.getDescription().value_or("")); | ||
|
|
||
| // Emit the main interface class declaration. | ||
| os << llvm::formatv( | ||
| "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n" | ||
| "public:\n", | ||
| interfaceName); | ||
|
|
||
| emitInterfaceDeclMethods(interface, os); | ||
| os << llvm::formatv("\nprotected:\n" | ||
| " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n", | ||
| interfaceName); | ||
|
|
||
| os << "};\n"; | ||
| } | ||
|
|
||
| void DialectInterfaceGenerator::emitInterfaceMethodsDef( | ||
| const Interface &interface) { | ||
|
|
||
| for (auto &method : interface.getMethods()) { | ||
| if (auto body = method.getBody()) { | ||
| emitCPPType(method.getReturnType(), os); | ||
| os << interface.getCppNamespace() << "::"; | ||
| os << interface.getName() << "::"; | ||
| emitMethodNameAndArgs(method, method.getName(), os); | ||
| os << " {\n " << body.value() << "\n}\n"; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| bool DialectInterfaceGenerator::emitInterfaceDecls() { | ||
|
|
||
| llvm::emitSourceFileHeader("Dialect Interface Declarations", os); | ||
|
|
||
| // Sort according to ID, so defs are emitted in the order in which they appear | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could have sworn I had this exposed somewhere :)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I grabbed this one from |
||
| // in the Tablegen file. | ||
| std::vector<const Record *> sortedDefs(defs); | ||
| llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { | ||
| return lhs->getID() < rhs->getID(); | ||
| }); | ||
|
|
||
| for (const Record *def : sortedDefs) | ||
| emitInterfaceDecl(Interface(def)); | ||
|
|
||
| os << "\n"; | ||
| for (const Record *def : sortedDefs) | ||
| emitInterfaceMethodsDef(Interface(def)); | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // GEN: Interface registration hooks | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| static mlir::GenRegistration genDecls( | ||
| "gen-dialect-interface-decls", "Generate dialect interface declarations.", | ||
| [](const RecordKeeper &records, raw_ostream &os) { | ||
| return DialectInterfaceGenerator(records, os).emitInterfaceDecls(); | ||
| }); | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No change needed, but you could also use indented ostream here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used indented ostream. Can you please review again to see if that's the correct usage?