|
| 1 | +//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// DialectInterfaceGen generates definitions for Dialect interfaces. |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "CppGenUtilities.h" |
| 14 | +#include "DocGenUtilities.h" |
| 15 | +#include "mlir/TableGen/GenInfo.h" |
| 16 | +#include "mlir/TableGen/Interfaces.h" |
| 17 | +#include "llvm/ADT/StringExtras.h" |
| 18 | +#include "llvm/Support/FormatVariadic.h" |
| 19 | +#include "llvm/Support/raw_ostream.h" |
| 20 | +#include "llvm/TableGen/CodeGenHelpers.h" |
| 21 | +#include "llvm/TableGen/Error.h" |
| 22 | +#include "llvm/TableGen/Record.h" |
| 23 | +#include "llvm/TableGen/TableGenBackend.h" |
| 24 | + |
| 25 | +using namespace mlir; |
| 26 | +using llvm::Record; |
| 27 | +using llvm::RecordKeeper; |
| 28 | +using mlir::tblgen::Interface; |
| 29 | +using mlir::tblgen::InterfaceMethod; |
| 30 | + |
| 31 | +/// Emit a string corresponding to a C++ type, followed by a space if necessary. |
| 32 | +static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { |
| 33 | + type = type.trim(); |
| 34 | + os << type; |
| 35 | + if (type.back() != '&' && type.back() != '*') |
| 36 | + os << " "; |
| 37 | + return os; |
| 38 | +} |
| 39 | + |
| 40 | +/// Emit the method name and argument list for the given method. |
| 41 | +static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, |
| 42 | + raw_ostream &os) { |
| 43 | + os << name << '('; |
| 44 | + llvm::interleaveComma(method.getArguments(), os, |
| 45 | + [&](const InterfaceMethod::Argument &arg) { |
| 46 | + os << arg.type << " " << arg.name; |
| 47 | + }); |
| 48 | + os << ") const"; |
| 49 | +} |
| 50 | + |
| 51 | +/// Get an array of all Dialect Interface definitions |
| 52 | +static std::vector<const Record *> |
| 53 | +getAllInterfaceDefinitions(const RecordKeeper &records) { |
| 54 | + std::vector<const Record *> defs = |
| 55 | + records.getAllDerivedDefinitions("DialectInterface"); |
| 56 | + |
| 57 | + llvm::erase_if(defs, [&](const Record *def) { |
| 58 | + // Ignore interfaces defined outside of the top-level file. |
| 59 | + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != |
| 60 | + llvm::SrcMgr.getMainFileID(); |
| 61 | + }); |
| 62 | + return defs; |
| 63 | +} |
| 64 | + |
| 65 | +namespace { |
| 66 | +/// This struct is the generator used when processing tablegen dialect |
| 67 | +/// interfaces. |
| 68 | +class DialectInterfaceGenerator { |
| 69 | +public: |
| 70 | + DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) |
| 71 | + : defs(getAllInterfaceDefinitions(records)), os(os) {} |
| 72 | + |
| 73 | + bool emitInterfaceDecls(); |
| 74 | + |
| 75 | +protected: |
| 76 | + void emitInterfaceDecl(const Interface &interface); |
| 77 | + void emitInterfaceMethodsDef(const Interface &interface); |
| 78 | + |
| 79 | + /// The set of interface records to emit. |
| 80 | + std::vector<const Record *> defs; |
| 81 | + // The stream to emit to. |
| 82 | + raw_ostream &os; |
| 83 | +}; |
| 84 | +} // namespace |
| 85 | + |
| 86 | +//===----------------------------------------------------------------------===// |
| 87 | +// GEN: Interface declarations |
| 88 | +//===----------------------------------------------------------------------===// |
| 89 | + |
| 90 | +static void emitInterfaceMethodDoc(const InterfaceMethod &method, |
| 91 | + raw_ostream &os, StringRef prefix = "") { |
| 92 | + if (std::optional<StringRef> description = method.getDescription()) |
| 93 | + tblgen::emitDescriptionComment(*description, os, prefix); |
| 94 | +} |
| 95 | + |
| 96 | +static void emitInterfaceDeclMethods(const Interface &interface, |
| 97 | + raw_ostream &os) { |
| 98 | + for (auto &method : interface.getMethods()) { |
| 99 | + emitInterfaceMethodDoc(method, os, " "); |
| 100 | + os << " virtual "; |
| 101 | + emitCPPType(method.getReturnType(), os); |
| 102 | + emitMethodNameAndArgs(method, method.getName(), os); |
| 103 | + if (!method.getBody()) |
| 104 | + // no default method body |
| 105 | + os << " = 0"; |
| 106 | + os << ";\n"; |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) { |
| 111 | + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); |
| 112 | + |
| 113 | + StringRef interfaceName = interface.getName(); |
| 114 | + |
| 115 | + tblgen::emitSummaryAndDescComments(os, "", |
| 116 | + interface.getDescription().value_or("")); |
| 117 | + |
| 118 | + // Emit the main interface class declaration. |
| 119 | + os << llvm::formatv( |
| 120 | + "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n" |
| 121 | + "public:\n", |
| 122 | + interfaceName); |
| 123 | + |
| 124 | + emitInterfaceDeclMethods(interface, os); |
| 125 | + os << llvm::formatv("\nprotected:\n" |
| 126 | + " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n", |
| 127 | + interfaceName); |
| 128 | + |
| 129 | + os << "};\n"; |
| 130 | +} |
| 131 | + |
| 132 | +void DialectInterfaceGenerator::emitInterfaceMethodsDef( |
| 133 | + const Interface &interface) { |
| 134 | + |
| 135 | + for (auto &method : interface.getMethods()) { |
| 136 | + if (auto body = method.getBody()) { |
| 137 | + emitCPPType(method.getReturnType(), os); |
| 138 | + os << interface.getCppNamespace() << "::"; |
| 139 | + os << interface.getName() << "::"; |
| 140 | + emitMethodNameAndArgs(method, method.getName(), os); |
| 141 | + os << " {\n " << body.value() << "\n}\n"; |
| 142 | + } |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +bool DialectInterfaceGenerator::emitInterfaceDecls() { |
| 147 | + |
| 148 | + llvm::emitSourceFileHeader("Dialect Interface Declarations", os); |
| 149 | + |
| 150 | + // Sort according to ID, so defs are emitted in the order in which they appear |
| 151 | + // in the Tablegen file. |
| 152 | + std::vector<const Record *> sortedDefs(defs); |
| 153 | + llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { |
| 154 | + return lhs->getID() < rhs->getID(); |
| 155 | + }); |
| 156 | + |
| 157 | + for (const Record *def : sortedDefs) |
| 158 | + emitInterfaceDecl(Interface(def)); |
| 159 | + |
| 160 | + os << "\n"; |
| 161 | + for (const Record *def : sortedDefs) |
| 162 | + emitInterfaceMethodsDef(Interface(def)); |
| 163 | + |
| 164 | + return false; |
| 165 | +} |
| 166 | + |
| 167 | +//===----------------------------------------------------------------------===// |
| 168 | +// GEN: Interface registration hooks |
| 169 | +//===----------------------------------------------------------------------===// |
| 170 | + |
| 171 | +static mlir::GenRegistration genDecls( |
| 172 | + "gen-dialect-interface-decls", |
| 173 | + "Generate dialect interface declarations.", |
| 174 | + [](const RecordKeeper &records, raw_ostream &os) { |
| 175 | + return DialectInterfaceGenerator(records, os).emitInterfaceDecls(); |
| 176 | + }); |
0 commit comments