Skip to content

Commit 504bb5c

Browse files
committed
feat(mlir-tblgen): Add support for dialect interfaces
1 parent e413343 commit 504bb5c

File tree

6 files changed

+263
-0
lines changed

6 files changed

+263
-0
lines changed

mlir/include/mlir/IR/Interfaces.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
147147
!if(!empty(cppNamespace),"", cppNamespace # "::") # name
148148
>;
149149

150+
// DialectInterface represents an interface registered to an operation.
151+
class DialectInterface<string name, list<Interface> baseInterfaces = []>
152+
: Interface<name, baseInterfaces>, OpInterfaceTrait<name>;
153+
154+
150155
// Whether to declare the interface methods in the user entity's header. This
151156
// class simply wraps an Interface but is used to indicate that the method
152157
// declarations should be generated. This class takes an optional set of methods

mlir/include/mlir/TableGen/Interfaces.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,13 @@ struct TypeInterface : public Interface {
157157

158158
static bool classof(const Interface *interface);
159159
};
160+
// An interface that is registered to a Dialect.
161+
struct DialectInterface : public Interface {
162+
using Interface::Interface;
163+
164+
static bool classof(const Interface *interface);
165+
};
166+
160167
} // namespace tblgen
161168
} // namespace mlir
162169

mlir/lib/TableGen/Interfaces.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
208208
bool TypeInterface::classof(const Interface *interface) {
209209
return interface->getDef().isSubClassOf("TypeInterface");
210210
}
211+
212+
//===----------------------------------------------------------------------===//
213+
// DialectInterface
214+
//===----------------------------------------------------------------------===//
215+
216+
bool DialectInterface::classof(const Interface *interface) {
217+
return interface->getDef().isSubClassOf("DialectInterface");
218+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
2+
3+
include "mlir/IR/Interfaces.td"
4+
5+
def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
6+
let description = [{
7+
This is an example dialect interface without default method body.
8+
}];
9+
10+
let cppNamespace = "::mlir::example";
11+
12+
let methods = [
13+
InterfaceMethod<
14+
/*desc=*/ "Check if it's an example dialect",
15+
/*returnType=*/ "bool",
16+
/*methodName=*/ "isExampleDialect",
17+
/*args=*/ (ins)
18+
>,
19+
InterfaceMethod<
20+
/*desc=*/ "second method to check if multiple methods supported",
21+
/*returnType=*/ "unsigned",
22+
/*methodName=*/ "supportSecondMethod",
23+
/*args=*/ (ins "::mlir::Type":$type)
24+
>
25+
26+
];
27+
}
28+
29+
// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
30+
// DECL: virtual bool isExampleDialect() const = 0;
31+
// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
32+
// DECL: protected:
33+
// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
34+
35+
def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
36+
let description = [{
37+
This is an example dialect interface with default method bodies.
38+
}];
39+
40+
let cppNamespace = "::mlir::example";
41+
42+
let methods = [
43+
InterfaceMethod<
44+
/*desc=*/ "Check if it's an example dialect",
45+
/*returnType=*/ "bool",
46+
/*methodName=*/ "isExampleDialect",
47+
/*args=*/ (ins),
48+
/*methodBody=*/ [{
49+
return true;
50+
}]
51+
>,
52+
InterfaceMethod<
53+
/*desc=*/ "second method to check if multiple methods supported",
54+
/*returnType=*/ "unsigned",
55+
/*methodName=*/ "supportSecondMethod",
56+
/*args=*/ (ins "::mlir::Type":$type)
57+
>
58+
59+
];
60+
}
61+
62+
// DECL: virtual bool isExampleDialect() const;
63+
// DECL: bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const {
64+
// DECL-NEXT: return true;
65+
// DECL-NEXT: }
66+

mlir/tools/mlir-tblgen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
1212
AttrOrTypeFormatGen.cpp
1313
BytecodeDialectGen.cpp
1414
DialectGen.cpp
15+
DialectInterfacesGen.cpp
1516
DirectiveCommonGen.cpp
1617
EnumsGen.cpp
1718
EnumPythonBindingGen.cpp
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// DialectInterfaceGen generates definitions for Dialect interfaces.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "CppGenUtilities.h"
14+
#include "DocGenUtilities.h"
15+
#include "mlir/TableGen/GenInfo.h"
16+
#include "mlir/TableGen/Interfaces.h"
17+
#include "llvm/ADT/StringExtras.h"
18+
#include "llvm/Support/FormatVariadic.h"
19+
#include "llvm/Support/raw_ostream.h"
20+
#include "llvm/TableGen/CodeGenHelpers.h"
21+
#include "llvm/TableGen/Error.h"
22+
#include "llvm/TableGen/Record.h"
23+
#include "llvm/TableGen/TableGenBackend.h"
24+
25+
using namespace mlir;
26+
using llvm::Record;
27+
using llvm::RecordKeeper;
28+
using mlir::tblgen::Interface;
29+
using mlir::tblgen::InterfaceMethod;
30+
31+
/// Emit a string corresponding to a C++ type, followed by a space if necessary.
32+
static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
33+
type = type.trim();
34+
os << type;
35+
if (type.back() != '&' && type.back() != '*')
36+
os << " ";
37+
return os;
38+
}
39+
40+
/// Emit the method name and argument list for the given method.
41+
static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
42+
raw_ostream &os) {
43+
os << name << '(';
44+
llvm::interleaveComma(method.getArguments(), os,
45+
[&](const InterfaceMethod::Argument &arg) {
46+
os << arg.type << " " << arg.name;
47+
});
48+
os << ") const";
49+
}
50+
51+
/// Get an array of all Dialect Interface definitions
52+
static std::vector<const Record *>
53+
getAllInterfaceDefinitions(const RecordKeeper &records) {
54+
std::vector<const Record *> defs =
55+
records.getAllDerivedDefinitions("DialectInterface");
56+
57+
llvm::erase_if(defs, [&](const Record *def) {
58+
// Ignore interfaces defined outside of the top-level file.
59+
return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
60+
llvm::SrcMgr.getMainFileID();
61+
});
62+
return defs;
63+
}
64+
65+
namespace {
66+
/// This struct is the generator used when processing tablegen dialect
67+
/// interfaces.
68+
class DialectInterfaceGenerator {
69+
public:
70+
DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
71+
: defs(getAllInterfaceDefinitions(records)), os(os) {}
72+
73+
bool emitInterfaceDecls();
74+
75+
protected:
76+
void emitInterfaceDecl(const Interface &interface);
77+
void emitInterfaceMethodsDef(const Interface &interface);
78+
79+
/// The set of interface records to emit.
80+
std::vector<const Record *> defs;
81+
// The stream to emit to.
82+
raw_ostream &os;
83+
};
84+
} // namespace
85+
86+
//===----------------------------------------------------------------------===//
87+
// GEN: Interface declarations
88+
//===----------------------------------------------------------------------===//
89+
90+
static void emitInterfaceMethodDoc(const InterfaceMethod &method,
91+
raw_ostream &os, StringRef prefix = "") {
92+
if (std::optional<StringRef> description = method.getDescription())
93+
tblgen::emitDescriptionComment(*description, os, prefix);
94+
}
95+
96+
static void emitInterfaceDeclMethods(const Interface &interface,
97+
raw_ostream &os) {
98+
for (auto &method : interface.getMethods()) {
99+
emitInterfaceMethodDoc(method, os, " ");
100+
os << " virtual ";
101+
emitCPPType(method.getReturnType(), os);
102+
emitMethodNameAndArgs(method, method.getName(), os);
103+
if (!method.getBody())
104+
// no default method body
105+
os << " = 0";
106+
os << ";\n";
107+
}
108+
}
109+
110+
void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
111+
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
112+
113+
StringRef interfaceName = interface.getName();
114+
115+
tblgen::emitSummaryAndDescComments(os, "",
116+
interface.getDescription().value_or(""));
117+
118+
// Emit the main interface class declaration.
119+
os << llvm::formatv(
120+
"class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
121+
"public:\n",
122+
interfaceName);
123+
124+
emitInterfaceDeclMethods(interface, os);
125+
os << llvm::formatv("\nprotected:\n"
126+
" {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
127+
interfaceName);
128+
129+
os << "};\n";
130+
}
131+
132+
void DialectInterfaceGenerator::emitInterfaceMethodsDef(
133+
const Interface &interface) {
134+
135+
for (auto &method : interface.getMethods()) {
136+
if (auto body = method.getBody()) {
137+
emitCPPType(method.getReturnType(), os);
138+
os << interface.getCppNamespace() << "::";
139+
os << interface.getName() << "::";
140+
emitMethodNameAndArgs(method, method.getName(), os);
141+
os << " {\n " << body.value() << "\n}\n";
142+
}
143+
}
144+
}
145+
146+
bool DialectInterfaceGenerator::emitInterfaceDecls() {
147+
148+
llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
149+
150+
// Sort according to ID, so defs are emitted in the order in which they appear
151+
// in the Tablegen file.
152+
std::vector<const Record *> sortedDefs(defs);
153+
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
154+
return lhs->getID() < rhs->getID();
155+
});
156+
157+
for (const Record *def : sortedDefs)
158+
emitInterfaceDecl(Interface(def));
159+
160+
os << "\n";
161+
for (const Record *def : sortedDefs)
162+
emitInterfaceMethodsDef(Interface(def));
163+
164+
return false;
165+
}
166+
167+
//===----------------------------------------------------------------------===//
168+
// GEN: Interface registration hooks
169+
//===----------------------------------------------------------------------===//
170+
171+
static mlir::GenRegistration genDecls(
172+
"gen-dialect-interface-decls",
173+
"Generate dialect interface declarations.",
174+
[](const RecordKeeper &records, raw_ostream &os) {
175+
return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
176+
});

0 commit comments

Comments
 (0)