diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 7e1c5fe075675..d3e1888c48baf 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -85,6 +85,72 @@ if (DialectInlinerInterface *interface = dyn_cast(diale } ``` +#### Utilizing the ODS framework + +Note: Before reading this section, the reader should have some familiarity with +the concepts described in the +[`Operation Definition Specification`](DefiningDialects/Operations.md) documentation. + +MLIR also supports defining dialect interfaces directly in **TableGen**. +This reduces boilerplate and allows authors to specify high-level interface +structure declaratively. + +For example, the above interface can be defined using ODS as follows: + +```tablegen +def DialectInlinerInterface : DialectInterface<"DialectInlinerInterface"> { + let description = [{ + Define a base inlining interface class to allow for dialects to opt-in to + the inliner. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns true if the given region 'src' can be inlined into the region + 'dest' that is attached to an operation registered to the current dialect. + 'valueMapping' contains any remapped values from within the 'src' region. + This can be used to examine what values will replace entry arguments into + the 'src' region, for example. + }], + "bool", "isLegalToInline", + (ins "Region *":$dest, "Region *":$src, "IRMapping &":$valueMapping), + [{ + return false; + }] + > + ]; +} +``` + +`DialectInterfaces` class make use of the following components: + +* C++ Class Name (Provided via template parameter) + - The name of the C++ interface class. +* Description (`description`) + - A string description of the interface, its invariants, example usages, + etc. +* C++ Namespace (`cppNamespace`) + - The C++ namespace that the interface class should be generated in. +* Methods (`methods`) + - The list of interface hook methods that are defined by the IR object. + - The structure of these methods is defined [here](#interface-methods). + +The header file can be generated via the following command: + +```bash +mlir-tblgen --gen-dialect-interface-decls DialectInterface.td +``` + +To generate dialect interface declarations using the ODS framework in CMake, you would write: + +```cmake +set(LLVM_TARGET_DEFINITIONS DialectInlinerInterface.td) +mlir_tablegen(DialectInlinerInterface.h.inc -gen-dialect-interface-decls) +``` + +An example of this can be found in the DialectInlinerInterface implementation +and the related `CMakeLists.txt` under `mlir/include/mlir/Transforms`. + #### DialectInterfaceCollection An additional utility is provided via `DialectInterfaceCollection`. This class @@ -364,10 +430,6 @@ void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, #### Utilizing the ODS Framework -Note: Before reading this section, the reader should have some familiarity with -the concepts described in the -[`Operation Definition Specification`](DefiningDialects/Operations.md) documentation. - As detailed above, [Interfaces](#attributeoperationtype-interfaces) allow for attributes, operations, and types to expose method calls without requiring that the caller know the specific derived type. The downside to this infrastructure, diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td index 0cbe3fa25c9e7..e51bbd5620280 100644 --- a/mlir/include/mlir/IR/Interfaces.td +++ b/mlir/include/mlir/IR/Interfaces.td @@ -147,6 +147,11 @@ class TypeInterface baseInterfaces = []> !if(!empty(cppNamespace),"", cppNamespace # "::") # name >; +// DialectInterface represents a Dialect Interface. +class DialectInterface baseInterfaces = []> + : Interface, OpInterfaceTrait; + + // Whether to declare the interface methods in the user entity's header. This // class simply wraps an Interface but is used to indicate that the method // declarations should be generated. This class takes an optional set of methods diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h index 7c36cbc1192ac..f62d21da467a1 100644 --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -157,6 +157,13 @@ struct TypeInterface : public Interface { static bool classof(const Interface *interface); }; +// An interface that is registered to a Dialect. +struct DialectInterface : public Interface { + using Interface::Interface; + + static bool classof(const Interface *interface); +}; + } // namespace tblgen } // namespace mlir diff --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt index 5fa52b28b6f1d..1b57a3482c1bb 100644 --- a/mlir/include/mlir/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Transforms/CMakeLists.txt @@ -5,4 +5,8 @@ mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header --prefix Transforms) mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl --prefix Transforms) add_mlir_dialect_tablegen_target(MLIRTransformsPassIncGen) +set(LLVM_TARGET_DEFINITIONS DialectInlinerInterface.td) +mlir_tablegen(DialectInlinerInterface.h.inc -gen-dialect-interface-decls) +add_mlir_dialect_tablegen_target(MLIRTransformsDialectInterfaceIncGen) + add_mlir_doc(Passes GeneralPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Transforms/DialectInlinerInterface.td b/mlir/include/mlir/Transforms/DialectInlinerInterface.td new file mode 100644 index 0000000000000..0975b84179d3c --- /dev/null +++ b/mlir/include/mlir/Transforms/DialectInlinerInterface.td @@ -0,0 +1,196 @@ +#ifndef MLIR_INTERFACES_DIALECTINLINERINTERFACE +#define MLIR_INTERFACES_DIALECTINLINERINTERFACE + +include "mlir/IR/Interfaces.td" + +def DialectInlinerInterface : DialectInterface<"DialectInlinerInterface"> { + let description = [{ + This is the interface that must be implemented by the dialects of operations + to be inlined. This interface should only handle the operations of the + given dialect. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Returns true if the given operation 'callable', that implements the + 'CallableOpInterface', can be inlined into the position given call + operation 'call', that is registered to the current dialect and implements + the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the + given 'callable' is set to be cloned during the inlining process, or false + if the region is set to be moved in-place(i.e. no duplicates would be + created). + }], + "bool", "isLegalToInline", + (ins "::mlir::Operation *":$call, "::mlir::Operation *":$callable, + "bool":$wouldBeCloned), + [{ + return false; + }] + >, + InterfaceMethod<[{ + Returns true if the given region 'src' can be inlined into the region + 'dest' that is attached to an operation registered to the current dialect. + 'wouldBeCloned' is set to true if the given 'src' region is set to be + cloned during the inlining process, or false if the region is set to be + moved in-place (i.e. no duplicates would be created). 'valueMapping' + contains any remapped values from within the 'src' region. This can be + used to examine what values will replace entry arguments into the 'src' + region for example. + }], + "bool", "isLegalToInline", + (ins "::mlir::Region *":$dest, "::mlir::Region *":$src, "bool":$wouldBeCloned, + "::mlir::IRMapping &":$valueMapping), + [{ + return false; + }] + >, + InterfaceMethod<[{ + Returns true if the given region 'src' can be inlined into the region + 'dest' that is attached to an operation registered to the current dialect. + 'wouldBeCloned' is set to true if the given 'src' region is set to be + cloned during the inlining process, or false if the region is set to be + moved in-place(i.e. no duplicates would be created). 'valueMapping' + contains any remapped values from within the 'src' region. This can be + used to examine what values will replace entry arguments into the 'src' + region for example. + }], + "bool", "isLegalToInline", + (ins "::mlir::Operation *":$op, "::mlir::Region *":$dest, + "bool":$wouldBeCloned, "::mlir::IRMapping &":$valueMapping), + [{ + return false; + }] + >, + InterfaceMethod<[{ + This hook is invoked on an operation that contains regions. It should + return true if the analyzer should recurse within the regions of this + operation when computing legality and cost, false otherwise. The default + implementation returns true. + }], + "bool", "shouldAnalyzeRecursively", + (ins "::mlir::Operation *":$op), + [{ + return true; + }] + >, + InterfaceMethod<[{ + Handle the given inlined terminator by replacing it with a new operation + as necessary. This overload is called when the inlined region has more + than one block. The 'newDest' block represents the new final branching + destination of blocks within this region, i.e. operations that release + control to the parent operation will likely now branch to this block. + Its block arguments correspond to any values that need to be replaced by + terminators within the inlined region. + }], + "void", "handleTerminator", + (ins "::mlir::Operation *":$op, "::mlir::Block *":$newDest), + [{ + llvm_unreachable("must implement handleTerminator in the case of multiple " + "inlined blocks"); + }] + >, + InterfaceMethod<[{ + Handle the given inlined terminator by replacing it with a new operation + as necessary. This overload is called when the inlined region only + contains one block. 'valuesToReplace' contains the previously returned + values of the call site before inlining. These values must be replaced by + this callback if they had any users (for example for traditional function + calls, these are directly replaced with the operands of the `return` + operation). The given 'op' will be removed by the caller, after this + function has been called. + }], + "void", "handleTerminator", + (ins "::mlir::Operation *":$op, "::mlir::ValueRange":$valuesToReplace), + [{ + llvm_unreachable( + "must implement handleTerminator in the case of one inlined block"); + }] + >, + InterfaceMethod<[{ + Attempt to materialize a conversion for a type mismatch between a call + from this dialect, and a callable region. This method should generate an + operation that takes 'input' as the only operand, and produces a single + result of 'resultType'. If a conversion can not be generated, nullptr + should be returned. For example, this hook may be invoked in the following + scenarios: + + ```mlir + func @foo(i32) -> i32 { ... } + + // Mismatched input operand ... = foo.call @foo(%input : i16) -> i32 + + // Mismatched result type. + ... = foo.call @foo(%input : i32) -> i16 + ``` + + NOTE: This hook may be invoked before the 'isLegal' checks above. + }], + "::mlir::Operation *", "materializeCallConversion", + (ins "::mlir::OpBuilder &":$builder, "::mlir::Value":$input, + "::mlir::Type":$resultType, "::mlir::Location":$conversionLoc), + [{ + return nullptr; + }] + >, + InterfaceMethod<[{ + Hook to transform the call arguments before using them to replace the + callee arguments. Returns a value of the same type or the `argument` + itself if nothing changed. The `argumentAttrs` dictionary is non-null even + if no attribute is present. The hook is called after converting the + callsite argument types using the materializeCallConversion callback, and + right before inlining the callee region. Any operations created using the + provided `builder` are inserted right before the inlined callee region. An + example use case is the insertion of copies for by value arguments. + }], + "::mlir::Value", "handleArgument", + (ins "::mlir::OpBuilder &":$builder, "::mlir::Operation *":$call, + "::mlir::Operation *":$callable, "::mlir::Value":$argument, + "::mlir::DictionaryAttr":$argumentAttrs), + [{ + return argument; + }] + >, + InterfaceMethod<[{ + Hook to transform the callee results before using them to replace the call + results. Returns a value of the same type or the `result` itself if + nothing changed. The `resultAttrs` dictionary is non-null even if no + attribute is present. The hook is called right before handling + terminators, and obtains the callee result before converting its type + using the `materializeCallConversion` callback. Any operations created + using the provided `builder` are inserted right after the inlined callee + region. An example use case is the insertion of copies for by value + results. NOTE: This hook is invoked after inlining the `callable` region. + }], + "::mlir::Value", "handleResult", + (ins "::mlir::OpBuilder &":$builder, "::mlir::Operation *":$call, + "::mlir::Operation *":$callable, "::mlir::Value":$result, + "::mlir::DictionaryAttr":$resultAttrs), + [{ + return result; + }] + >, + InterfaceMethod<[{ + Process a set of blocks that have been inlined for a call. This callback + is invoked before inlined terminator operations have been processed. + }], + "void", "processInlinedCallBlocks", + (ins "::mlir::Operation *":$call, + "::mlir::iterator_range<::mlir::Region::iterator>":$inlinedBlocks), + [{}] + >, + InterfaceMethod<[{ + Returns true if the inliner can assume a fast path of not creating a new + block, if there is only one block. + }], + "bool", "allowSingleBlockOptimization", + (ins "::mlir::iterator_range<::mlir::Region::iterator>":$inlinedBlocks), + [{ + return true; + }] + > + ]; +} + + +#endif diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index ed6413d8cd44c..b6c6da3ddcc9b 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -32,158 +32,7 @@ class Region; class TypeRange; class Value; class ValueRange; - -//===----------------------------------------------------------------------===// -// InlinerInterface -//===----------------------------------------------------------------------===// - -/// This is the interface that must be implemented by the dialects of operations -/// to be inlined. This interface should only handle the operations of the -/// given dialect. -class DialectInlinerInterface - : public DialectInterface::Base { -public: - DialectInlinerInterface(Dialect *dialect) : Base(dialect) {} - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - /// Returns true if the given operation 'callable', that implements the - /// 'CallableOpInterface', can be inlined into the position given call - /// operation 'call', that is registered to the current dialect and implements - /// the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the - /// given 'callable' is set to be cloned during the inlining process, or false - /// if the region is set to be moved in-place(i.e. no duplicates would be - /// created). - virtual bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const { - return false; - } - - /// Returns true if the given region 'src' can be inlined into the region - /// 'dest' that is attached to an operation registered to the current dialect. - /// 'wouldBeCloned' is set to true if the given 'src' region is set to be - /// cloned during the inlining process, or false if the region is set to be - /// moved in-place(i.e. no duplicates would be created). 'valueMapping' - /// contains any remapped values from within the 'src' region. This can be - /// used to examine what values will replace entry arguments into the 'src' - /// region for example. - virtual bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - IRMapping &valueMapping) const { - return false; - } - - /// Returns true if the given operation 'op', that is registered to this - /// dialect, can be inlined into the given region, false otherwise. - /// 'wouldBeCloned' is set to true if the given 'op' is set to be cloned - /// during the inlining process, or false if the operation is set to be moved - /// in-place(i.e. no duplicates would be created). 'valueMapping' contains any - /// remapped values from within the 'src' region. This can be used to examine - /// what values may potentially replace the operands to 'op'. - virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, - IRMapping &valueMapping) const { - return false; - } - - /// This hook is invoked on an operation that contains regions. It should - /// return true if the analyzer should recurse within the regions of this - /// operation when computing legality and cost, false otherwise. The default - /// implementation returns true. - virtual bool shouldAnalyzeRecursively(Operation *op) const { return true; } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. This overload is called when the inlined region has more - /// than one block. The 'newDest' block represents the new final branching - /// destination of blocks within this region, i.e. operations that release - /// control to the parent operation will likely now branch to this block. - /// Its block arguments correspond to any values that need to be replaced by - /// terminators within the inlined region. - virtual void handleTerminator(Operation *op, Block *newDest) const { - llvm_unreachable("must implement handleTerminator in the case of multiple " - "inlined blocks"); - } - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. This overload is called when the inlined region only - /// contains one block. 'valuesToReplace' contains the previously returned - /// values of the call site before inlining. These values must be replaced by - /// this callback if they had any users (for example for traditional function - /// calls, these are directly replaced with the operands of the `return` - /// operation). The given 'op' will be removed by the caller, after this - /// function has been called. - virtual void handleTerminator(Operation *op, - ValueRange valuesToReplace) const { - llvm_unreachable( - "must implement handleTerminator in the case of one inlined block"); - } - - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. For example, this hook may be invoked in the following - /// scenarios: - /// func @foo(i32) -> i32 { ... } - /// - /// // Mismatched input operand - /// ... = foo.call @foo(%input : i16) -> i32 - /// - /// // Mismatched result type. - /// ... = foo.call @foo(%input : i32) -> i16 - /// - /// NOTE: This hook may be invoked before the 'isLegal' checks above. - virtual Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const { - return nullptr; - } - - /// Hook to transform the call arguments before using them to replace the - /// callee arguments. Returns a value of the same type or the `argument` - /// itself if nothing changed. The `argumentAttrs` dictionary is non-null even - /// if no attribute is present. The hook is called after converting the - /// callsite argument types using the materializeCallConversion callback, and - /// right before inlining the callee region. Any operations created using the - /// provided `builder` are inserted right before the inlined callee region. An - /// example use case is the insertion of copies for by value arguments. - virtual Value handleArgument(OpBuilder &builder, Operation *call, - Operation *callable, Value argument, - DictionaryAttr argumentAttrs) const { - return argument; - } - - /// Hook to transform the callee results before using them to replace the call - /// results. Returns a value of the same type or the `result` itself if - /// nothing changed. The `resultAttrs` dictionary is non-null even if no - /// attribute is present. The hook is called right before handling - /// terminators, and obtains the callee result before converting its type - /// using the `materializeCallConversion` callback. Any operations created - /// using the provided `builder` are inserted right after the inlined callee - /// region. An example use case is the insertion of copies for by value - /// results. NOTE: This hook is invoked after inlining the `callable` region. - virtual Value handleResult(OpBuilder &builder, Operation *call, - Operation *callable, Value result, - DictionaryAttr resultAttrs) const { - return result; - } - - /// Process a set of blocks that have been inlined for a call. This callback - /// is invoked before inlined terminator operations have been processed. - virtual void processInlinedCallBlocks( - Operation *call, iterator_range inlinedBlocks) const {} - - /// Returns true if the inliner can assume a fast path of not creating a new - /// block, if there is only one block. - virtual bool allowSingleBlockOptimization( - iterator_range inlinedBlocks) const { - return true; - } -}; +class DialectInlinerInterface; /// This interface provides the hooks into the inlining interface. /// Note: this class automatically collects 'DialectInlinerInterface' objects @@ -307,4 +156,6 @@ inlineCall(InlinerInterface &interface, } // namespace mlir +#include "mlir/Transforms/DialectInlinerInterface.h.inc" + #endif // MLIR_TRANSFORMS_INLININGUTILS_H diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index b0ad3ee59a089..77a6cecebbeaf 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) { bool TypeInterface::classof(const Interface *interface) { return interface->getDef().isSubClassOf("TypeInterface"); } + +//===----------------------------------------------------------------------===// +// DialectInterface +//===----------------------------------------------------------------------===// + +bool DialectInterface::classof(const Interface *interface) { + return interface->getDef().isSubClassOf("DialectInterface"); +} diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 54b67f5c7a91e..6c5303b4dd8a4 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -27,6 +27,7 @@ add_mlir_library(MLIRTransforms DEPENDS MLIRTransformsPassIncGen + MLIRTransformsDialectInterfaceIncGen LINK_LIBS PUBLIC MLIRAnalysis diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td new file mode 100644 index 0000000000000..ff39fd941f300 --- /dev/null +++ b/mlir/test/mlir-tblgen/dialect-interface.td @@ -0,0 +1,65 @@ +// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL + +include "mlir/IR/Interfaces.td" + +def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> { + let description = [{ + This is an example dialect interface without default method body. + }]; + + let cppNamespace = "::mlir::example"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Check if it's an example dialect", + /*returnType=*/ "bool", + /*methodName=*/ "isExampleDialect", + /*args=*/ (ins) + >, + InterfaceMethod< + /*desc=*/ "second method to check if multiple methods supported", + /*returnType=*/ "unsigned", + /*methodName=*/ "supportSecondMethod", + /*args=*/ (ins "::mlir::Type":$type) + > + + ]; +} + +// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base +// DECL: public: +// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {} +// DECL: virtual bool isExampleDialect() const {} +// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const {} + +def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> { + let description = [{ + This is an example dialect interface with default method bodies. + }]; + + let cppNamespace = "::mlir::example"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Check if it's an example dialect", + /*returnType=*/ "bool", + /*methodName=*/ "isExampleDialect", + /*args=*/ (ins), + /*methodBody=*/ [{ + return true; + }] + >, + InterfaceMethod< + /*desc=*/ "second method to check if multiple methods supported", + /*returnType=*/ "unsigned", + /*methodName=*/ "supportSecondMethod", + /*args=*/ (ins "::mlir::Type":$type) + > + + ]; +} + +// DECL: virtual bool isExampleDialect() const { +// DECL-NEXT: return true; +// DECL-NEXT: } + diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index 2a7ef7e0576c8..d7087cba3c874 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR AttrOrTypeFormatGen.cpp BytecodeDialectGen.cpp DialectGen.cpp + DialectInterfacesGen.cpp DirectiveCommonGen.cpp EnumsGen.cpp EnumPythonBindingGen.cpp diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp new file mode 100644 index 0000000000000..1d3b24a7aee15 --- /dev/null +++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp @@ -0,0 +1,164 @@ +//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DialectInterfaceGen generates definitions for Dialect interfaces. +// +//===----------------------------------------------------------------------===// + +#include "CppGenUtilities.h" +#include "DocGenUtilities.h" +#include "mlir/Support/IndentedOstream.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace mlir; +using llvm::Record; +using llvm::RecordKeeper; +using mlir::tblgen::Interface; +using mlir::tblgen::InterfaceMethod; + +/// Emit a string corresponding to a C++ type, followed by a space if necessary. +static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { + type = type.trim(); + os << type; + if (type.back() != '&' && type.back() != '*') + os << " "; + return os; +} + +/// Emit the method name and argument list for the given method. +static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, + raw_ostream &os) { + os << name << '('; + llvm::interleaveComma(method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); + os << ") const"; +} + +/// Get an array of all Dialect Interface definitions +static std::vector +getAllInterfaceDefinitions(const RecordKeeper &records) { + std::vector defs = + records.getAllDerivedDefinitions("DialectInterface"); + + llvm::erase_if(defs, [&](const Record *def) { + // Ignore interfaces defined outside of the top-level file. + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != + llvm::SrcMgr.getMainFileID(); + }); + return defs; +} + +namespace { +/// This struct is the generator used when processing tablegen dialect +/// interfaces. +class DialectInterfaceGenerator { +public: + DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) + : defs(getAllInterfaceDefinitions(records)), os(os) {} + + bool emitInterfaceDecls(); + +protected: + void emitInterfaceDecl(const Interface &interface); + + /// The set of interface records to emit. + std::vector defs; + // The stream to emit to. + raw_ostream &os; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// GEN: Interface declarations +//===----------------------------------------------------------------------===// + +static void emitInterfaceMethodDoc(const InterfaceMethod &method, + raw_ostream &os, StringRef prefix = "") { + if (std::optional description = method.getDescription()) + tblgen::emitDescriptionComment(*description, os, prefix); +} + +static void emitInterfaceMethodsDef(const Interface &interface, + raw_ostream &os) { + + raw_indented_ostream ios(os); + ios.indent(2); + + for (auto &method : interface.getMethods()) { + emitInterfaceMethodDoc(method, ios); + ios << "virtual "; + emitCPPType(method.getReturnType(), ios); + emitMethodNameAndArgs(method, method.getName(), ios); + ios << " {"; + + if (auto body = method.getBody()) { + ios << "\n"; + ios.indent(4); + ios << body << "\n"; + ios.indent(2); + } + os << "}\n"; + } +} + +void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) { + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); + + StringRef interfaceName = interface.getName(); + + tblgen::emitSummaryAndDescComments(os, "", + interface.getDescription().value_or("")); + + // Emit the main interface class declaration. + os << llvm::formatv( + "class {0} : public ::mlir::DialectInterface::Base<{0}> {\n" + "public:\n" + " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n", + interfaceName); + + emitInterfaceMethodsDef(interface, os); + + os << "};\n"; +} + +bool DialectInterfaceGenerator::emitInterfaceDecls() { + + llvm::emitSourceFileHeader("Dialect Interface Declarations", os); + + // Sort according to ID, so defs are emitted in the order in which they appear + // in the Tablegen file. + std::vector sortedDefs(defs); + llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { + return lhs->getID() < rhs->getID(); + }); + + for (const Record *def : sortedDefs) + emitInterfaceDecl(Interface(def)); + + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: Interface registration hooks +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration genDecls( + "gen-dialect-interface-decls", "Generate dialect interface declarations.", + [](const RecordKeeper &records, raw_ostream &os) { + return DialectInterfaceGenerator(records, os).emitInterfaceDecls(); + });