Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/Interfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
!if(!empty(cppNamespace),"", cppNamespace # "::") # name
>;

// DialectInterface represents an interface registered to an operation.
class DialectInterface<string name, list<Interface> baseInterfaces = []>
: Interface<name, baseInterfaces>, OpInterfaceTrait<name>;


// Whether to declare the interface methods in the user entity's header. This
// class simply wraps an Interface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/TableGen/Interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ struct TypeInterface : public Interface {

static bool classof(const Interface *interface);
};
// An interface that is registered to a Dialect.
struct DialectInterface : public Interface {
using Interface::Interface;

static bool classof(const Interface *interface);
};

} // namespace tblgen
} // namespace mlir

Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/TableGen/Interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
bool TypeInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("TypeInterface");
}

//===----------------------------------------------------------------------===//
// DialectInterface
//===----------------------------------------------------------------------===//

bool DialectInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("DialectInterface");
}
66 changes: 66 additions & 0 deletions mlir/test/mlir-tblgen/dialect-interface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL

include "mlir/IR/Interfaces.td"

def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
let description = [{
This is an example dialect interface without default method body.
}];

let cppNamespace = "::mlir::example";

let methods = [
InterfaceMethod<
/*desc=*/ "Check if it's an example dialect",
/*returnType=*/ "bool",
/*methodName=*/ "isExampleDialect",
/*args=*/ (ins)
>,
InterfaceMethod<
/*desc=*/ "second method to check if multiple methods supported",
/*returnType=*/ "unsigned",
/*methodName=*/ "supportSecondMethod",
/*args=*/ (ins "::mlir::Type":$type)
>

];
}

// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
// DECL: virtual bool isExampleDialect() const = 0;
// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
// DECL: protected:
// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}

def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
let description = [{
This is an example dialect interface with default method bodies.
}];

let cppNamespace = "::mlir::example";

let methods = [
InterfaceMethod<
/*desc=*/ "Check if it's an example dialect",
/*returnType=*/ "bool",
/*methodName=*/ "isExampleDialect",
/*args=*/ (ins),
/*methodBody=*/ [{
return true;
}]
>,
InterfaceMethod<
/*desc=*/ "second method to check if multiple methods supported",
/*returnType=*/ "unsigned",
/*methodName=*/ "supportSecondMethod",
/*args=*/ (ins "::mlir::Type":$type)
>

];
}

// DECL: virtual bool isExampleDialect() const;
// DECL: bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const {
// DECL-NEXT: return true;
// DECL-NEXT: }

1 change: 1 addition & 0 deletions mlir/tools/mlir-tblgen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
AttrOrTypeFormatGen.cpp
BytecodeDialectGen.cpp
DialectGen.cpp
DialectInterfacesGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
EnumPythonBindingGen.cpp
Expand Down
175 changes: 175 additions & 0 deletions mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// DialectInterfaceGen generates definitions for Dialect interfaces.
//
//===----------------------------------------------------------------------===//

#include "CppGenUtilities.h"
#include "DocGenUtilities.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/CodeGenHelpers.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using mlir::tblgen::Interface;
using mlir::tblgen::InterfaceMethod;

/// Emit a string corresponding to a C++ type, followed by a space if necessary.
static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
type = type.trim();
os << type;
if (type.back() != '&' && type.back() != '*')
os << " ";
return os;
}

/// Emit the method name and argument list for the given method.
static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
raw_ostream &os) {
os << name << '(';
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
os << ") const";
}

/// Get an array of all Dialect Interface definitions
static std::vector<const Record *>
getAllInterfaceDefinitions(const RecordKeeper &records) {
std::vector<const Record *> defs =
records.getAllDerivedDefinitions("DialectInterface");

llvm::erase_if(defs, [&](const Record *def) {
// Ignore interfaces defined outside of the top-level file.
return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
llvm::SrcMgr.getMainFileID();
});
return defs;
}

namespace {
/// This struct is the generator used when processing tablegen dialect
/// interfaces.
class DialectInterfaceGenerator {
public:
DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: defs(getAllInterfaceDefinitions(records)), os(os) {}

bool emitInterfaceDecls();

protected:
void emitInterfaceDecl(const Interface &interface);
void emitInterfaceMethodsDef(const Interface &interface);

/// The set of interface records to emit.
std::vector<const Record *> defs;
// The stream to emit to.
raw_ostream &os;
};
} // namespace

//===----------------------------------------------------------------------===//
// GEN: Interface declarations
//===----------------------------------------------------------------------===//

static void emitInterfaceMethodDoc(const InterfaceMethod &method,
raw_ostream &os, StringRef prefix = "") {
if (std::optional<StringRef> description = method.getDescription())
tblgen::emitDescriptionComment(*description, os, prefix);
}

static void emitInterfaceDeclMethods(const Interface &interface,
raw_ostream &os) {
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
os << " virtual ";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No change needed, but you could also use indented ostream here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used indented ostream. Can you please review again to see if that's the correct usage?

emitCPPType(method.getReturnType(), os);
emitMethodNameAndArgs(method, method.getName(), os);
if (!method.getBody())
// no default method body
os << " = 0";
os << ";\n";
}
}

void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());

StringRef interfaceName = interface.getName();

tblgen::emitSummaryAndDescComments(os, "",
interface.getDescription().value_or(""));

// Emit the main interface class declaration.
os << llvm::formatv(
"class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
"public:\n",
interfaceName);

emitInterfaceDeclMethods(interface, os);
os << llvm::formatv("\nprotected:\n"
" {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
interfaceName);

os << "};\n";
}

void DialectInterfaceGenerator::emitInterfaceMethodsDef(
const Interface &interface) {

for (auto &method : interface.getMethods()) {
if (auto body = method.getBody()) {
emitCPPType(method.getReturnType(), os);
os << interface.getCppNamespace() << "::";
os << interface.getName() << "::";
emitMethodNameAndArgs(method, method.getName(), os);
os << " {\n " << body.value() << "\n}\n";
}
}
}

bool DialectInterfaceGenerator::emitInterfaceDecls() {

llvm::emitSourceFileHeader("Dialect Interface Declarations", os);

// Sort according to ID, so defs are emitted in the order in which they appear
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could have sworn I had this exposed somewhere :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I grabbed this one from OpInterfacesGen.cpp. I haven't done any research if this has been implemented somewhere else. Do you remember where it might be implemented?

// in the Tablegen file.
std::vector<const Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
});

for (const Record *def : sortedDefs)
emitInterfaceDecl(Interface(def));

os << "\n";
for (const Record *def : sortedDefs)
emitInterfaceMethodsDef(Interface(def));

return false;
}

//===----------------------------------------------------------------------===//
// GEN: Interface registration hooks
//===----------------------------------------------------------------------===//

static mlir::GenRegistration genDecls(
"gen-dialect-interface-decls", "Generate dialect interface declarations.",
[](const RecordKeeper &records, raw_ostream &os) {
return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
});