diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index daedc2be22588..44493b75b8a8c 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -68,6 +68,7 @@ list(INSERT CMAKE_MODULE_PATH 0 ) include(AddMLIR) +include(IRDLToCpp) # -BSymbolic is incompatible with TypeID if("${CMAKE_SHARED_LINKER_FLAGS}" MATCHES "-Bsymbolic[^-]") @@ -211,6 +212,7 @@ include_directories(BEFORE # Adding tools/mlir-tblgen here as calling add_tablegen sets some variables like # MLIR_TABLEGEN_EXE in PARENT_SCOPE which gets lost if that folder is included # from another directory like tools +add_subdirectory(tools/mlir-irdl-to-cpp) add_subdirectory(tools/mlir-linalg-ods-gen) add_subdirectory(tools/mlir-pdll) add_subdirectory(tools/mlir-tblgen) diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt index 96c27044c4e8d..47349f0c7eb62 100644 --- a/mlir/cmake/modules/CMakeLists.txt +++ b/mlir/cmake/modules/CMakeLists.txt @@ -112,6 +112,8 @@ set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen) set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll) set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder) +set(MLIR_CONFIG_IRDL_TO_CPP_EXE mlir-irdl-to-cpp) + configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/MLIRConfig.cmake @@ -135,6 +137,7 @@ if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY) ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/MLIRConfigVersion.cmake ${CMAKE_CURRENT_SOURCE_DIR}/AddMLIR.cmake ${CMAKE_CURRENT_SOURCE_DIR}/AddMLIRPython.cmake + ${CMAKE_CURRENT_SOURCE_DIR}/IRDLToCpp.cmake ${CMAKE_CURRENT_SOURCE_DIR}/MLIRDetectPythonEnv.cmake DESTINATION ${MLIR_INSTALL_PACKAGE_DIR} COMPONENT mlir-cmake-exports) diff --git a/mlir/cmake/modules/IRDLToCpp.cmake b/mlir/cmake/modules/IRDLToCpp.cmake new file mode 100644 index 0000000000000..8470ccdf55166 --- /dev/null +++ b/mlir/cmake/modules/IRDLToCpp.cmake @@ -0,0 +1,12 @@ +function(add_irdl_to_cpp_target target irdl_file) + add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${irdl_file}.cpp.inc + COMMAND ${MLIR_IRDL_TO_CPP_EXE} ${CMAKE_CURRENT_SOURCE_DIR}/${irdl_file} -o ${CMAKE_CURRENT_BINARY_DIR}/${irdl_file}.cpp.inc + + # The command output depends on the executable to ensure IRDL sources are properly rebuilt + # if the tool changes. + DEPENDS ${MLIR_IRDL_TO_CPP_EXE} ${CMAKE_CURRENT_SOURCE_DIR}/${irdl_file} + COMMENT "Building ${irdl_file}..." + ) + add_custom_target(${target} DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${irdl_file}.cpp.inc) +endfunction() diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in index c695b5787af66..71f3e028b1e88 100644 --- a/mlir/cmake/modules/MLIRConfig.cmake.in +++ b/mlir/cmake/modules/MLIRConfig.cmake.in @@ -12,6 +12,7 @@ set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@") set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@") set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@") set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@") +set(MLIR_IRDL_TO_CPP_EXE "@MLIR_CONFIG_IRDL_TO_CPP_EXE@") set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@") set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@") set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@") diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td index 78b24ecd20773..4a83eb62fba32 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td @@ -91,6 +91,7 @@ def IRDL_TypeOp : IRDL_Op<"type", let regions = (region SizedRegion<1>:$body); let assemblyFormat = "$sym_name attr-dict-with-keyword custom($body)"; + let hasVerifier = 1; } def IRDL_AttributeOp : IRDL_Op<"attribute", @@ -126,6 +127,8 @@ def IRDL_AttributeOp : IRDL_Op<"attribute", let regions = (region SizedRegion<1>:$body); let assemblyFormat = "$sym_name attr-dict-with-keyword custom($body)"; + + let hasVerifier = 1; } def IRDL_ParametersOp : IRDL_Op<"parameters", @@ -202,6 +205,7 @@ def IRDL_OperationOp : IRDL_Op<"operation", let assemblyFormat = "$sym_name attr-dict-with-keyword custom($body)"; let hasRegionVerifier = true; + let hasVerifier = 1; } def IRDL_OperandsOp : IRDL_Op<"operands", [HasParent<"OperationOp">]> { diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h index 3de3e02ff3f81..1ab80fb27fa9a 100644 --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -14,6 +14,8 @@ #ifndef MLIR_INITALLTRANSLATIONS_H #define MLIR_INITALLTRANSLATIONS_H +#include "mlir/Target/IRDLToCpp/TranslationRegistration.h" + namespace mlir { void registerFromLLVMIRTranslation(); @@ -33,6 +35,7 @@ inline void registerAllTranslations() { static bool initOnce = []() { registerFromLLVMIRTranslation(); registerFromSPIRVTranslation(); + registerIRDLToCppTranslation(); registerToCppTranslation(); registerToLLVMIRTranslation(); registerToSPIRVTranslation(); diff --git a/mlir/include/mlir/Target/IRDLToCpp/IRDLToCpp.h b/mlir/include/mlir/Target/IRDLToCpp/IRDLToCpp.h new file mode 100644 index 0000000000000..940ce5632cb41 --- /dev/null +++ b/mlir/include/mlir/Target/IRDLToCpp/IRDLToCpp.h @@ -0,0 +1,38 @@ +//===- IRDLToCpp.h - Register translation -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the registration function for the IRDL to C++ translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_IRDLTOCPP_IRDLTOCPP_H +#define MLIR_TARGET_IRDLTOCPP_IRDLTOCPP_H + +#include "mlir/Dialect/IRDL/IR/IRDL.h" + +namespace mlir { +namespace irdl { + +/// Translates an IRDL dialect definition to a C++ definition that can be used +/// with MLIR. +/// +/// The following preprocessor macros will generate the following code: +/// +/// // This define generates code for the dialect's class declarations +/// #define GEN_DIALECT_DECL_HEADER +/// +/// // This define generates code for the dialect's class definitions +/// #define GEN_DIALECT_DEF +LogicalResult +translateIRDLDialectToCpp(llvm::ArrayRef dialects, + raw_ostream &output); + +} // namespace irdl +} // namespace mlir + +#endif // MLIR_TARGET_IRDLTOCPP_IRDLTOCPP_H diff --git a/mlir/include/mlir/Target/IRDLToCpp/TranslationRegistration.h b/mlir/include/mlir/Target/IRDLToCpp/TranslationRegistration.h new file mode 100644 index 0000000000000..430c519570862 --- /dev/null +++ b/mlir/include/mlir/Target/IRDLToCpp/TranslationRegistration.h @@ -0,0 +1,22 @@ +//===- TranslationRegistration.h - Register translation ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the registration function for the IRDL to C++ translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_IRDLTOCPP_TRANSLATIONREGISTRATION_H +#define MLIR_TARGET_IRDLTOCPP_TRANSLATIONREGISTRATION_H + +namespace mlir { + +void registerIRDLToCppTranslation(); + +} // namespace mlir + +#endif // MLIR_TARGET_IRDLTOCPP_TRANSLATIONREGISTRATION_H diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp index c0778d478619a..bcc9f0b109ac2 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp @@ -74,13 +74,62 @@ static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, if (!region.getBlocks().front().empty()) p.printRegion(region); } +static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, + const Twine &label) { + if (in.empty()) + return loc->emitError("name of ") << label << " is empty"; + + bool allowUnderscore = false; + for (auto &elem : in) { + if (elem == '_') { + if (!allowUnderscore) + return loc->emitError("name of ") + << label << " should not contain leading or double underscores"; + } else { + if (!isalnum(elem)) + return loc->emitError("name of ") + << label + << " must contain only lowercase letters, digits and " + "underscores"; + + if (llvm::isUpper(elem)) + return loc->emitError("name of ") + << label << " should not contain uppercase letters"; + } + + allowUnderscore = elem != '_'; + } + + return success(); +} LogicalResult DialectOp::verify() { if (!Dialect::isValidNamespace(getName())) return emitOpError("invalid dialect name"); + if (failed(isValidName(getSymName(), getOperation(), "dialect"))) + return failure(); + return success(); } +LogicalResult OperationOp::verify() { + return isValidName(getSymName(), getOperation(), "operation"); +} + +LogicalResult TypeOp::verify() { + auto symName = getSymName(); + if (symName.front() == '!') + symName = symName.substr(1); + return isValidName(symName, getOperation(), "type"); +} + +LogicalResult AttributeOp::verify() { + auto symName = getSymName(); + if (symName.front() == '#') + symName = symName.substr(1); + return isValidName(symName, getOperation(), "attribute"); +} + LogicalResult OperationOp::verifyRegions() { // Stores pairs of value kinds and the list of names of values of this kind in // the operation. @@ -133,18 +182,10 @@ static LogicalResult verifyNames(Operation *op, StringRef kindName, DenseMap nameMap; for (auto [i, name] : llvm::enumerate(names)) { StringRef nameRef = llvm::cast(name).getValue(); - if (nameRef.empty()) - return op->emitOpError() - << "name of " << kindName << " #" << i << " is empty"; - if (!llvm::isAlpha(nameRef[0]) && nameRef[0] != '_') - return op->emitOpError() - << "name of " << kindName << " #" << i - << " must start with either a letter or an underscore"; - if (llvm::any_of(nameRef, - [](char c) { return !llvm::isAlnum(c) && c != '_'; })) - return op->emitOpError() - << "name of " << kindName << " #" << i - << " must contain only letters, digits and underscores"; + + if (failed(isValidName(nameRef, op, Twine(kindName) + " #" + Twine(i)))) + return failure(); + if (nameMap.contains(nameRef)) return op->emitOpError() << "name of " << kindName << " #" << i << " is a duplicate of the name of " << kindName diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index f14ec49b5a0c2..6eb0abc214d38 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Cpp) +add_subdirectory(IRDLToCpp) add_subdirectory(SPIRV) add_subdirectory(LLVMIR) add_subdirectory(LLVM) diff --git a/mlir/lib/Target/IRDLToCpp/CMakeLists.txt b/mlir/lib/Target/IRDLToCpp/CMakeLists.txt new file mode 100644 index 0000000000000..731f3fea37451 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_translation_library(MLIRTargetIRDLToCpp + TranslationRegistration.cpp + IRDLToCpp.cpp + Templates/DialectDecl.txt + Templates/DialectDef.txt + Templates/Header.txt + Templates/PerOperationDecl.txt + Templates/PerOperationDef.txt + Templates/TypeDecl.txt + Templates/TypeDef.txt + Templates/TypeHeaderDecl.txt + Templates/TypeHeaderDef.txt + + LINK_LIBS PUBLIC + MLIRIR + MLIRIRDL + MLIRTranslateLib +) diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp new file mode 100644 index 0000000000000..538ea4f450b29 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -0,0 +1,585 @@ +//===- IRDLToCpp.cpp - Converts IRDL definitions to C++ -------------------===// +// +// Part of the LLVM Project, under the A0ache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/IRDLToCpp/IRDLToCpp.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" + +#include "TemplatingUtils.h" + +using namespace mlir; + +constexpr char headerTemplateText[] = +#include "Templates/Header.txt" + ; + +constexpr char declarationMacroFlag[] = "GEN_DIALECT_DECL_HEADER"; +constexpr char definitionMacroFlag[] = "GEN_DIALECT_DEF"; + +namespace { + +/// The set of strings that can be generated from a Dialect declaraiton +struct DialectStrings { + std::string dialectName; + std::string dialectCppName; + std::string dialectCppShortName; + std::string dialectBaseTypeName; + + std::string namespaceOpen; + std::string namespaceClose; + std::string namespacePath; +}; + +/// The set of strings that can be generated from a Type declaraiton +struct TypeStrings { + StringRef typeName; + std::string typeCppName; +}; + +/// The set of strings that can be generated from an Operation declaraiton +struct OpStrings { + StringRef opName; + std::string opCppName; + SmallVector opResultNames; + SmallVector opOperandNames; +}; + +static std::string joinNameList(llvm::ArrayRef names) { + std::string nameArray; + llvm::raw_string_ostream nameArrayStream(nameArray); + nameArrayStream << "{\"" << llvm::join(names, "\", \"") << "\"}"; + + return nameArray; +} + +/// Generates the C++ type name for a TypeOp +static std::string typeToCppName(irdl::TypeOp type) { + return llvm::formatv("{0}Type", + convertToCamelFromSnakeCase(type.getSymName(), true)); +} + +/// Generates the C++ class name for an OperationOp +static std::string opToCppName(irdl::OperationOp op) { + return llvm::formatv("{0}Op", + convertToCamelFromSnakeCase(op.getSymName(), true)); +} + +/// Generates TypeStrings from a TypeOp +static TypeStrings getStrings(irdl::TypeOp type) { + TypeStrings strings; + strings.typeName = type.getSymName(); + strings.typeCppName = typeToCppName(type); + return strings; +} + +/// Generates OpStrings from an OperatioOp +static OpStrings getStrings(irdl::OperationOp op) { + auto operandOp = op.getOp(); + + auto resultOp = op.getOp(); + + OpStrings strings; + strings.opName = op.getSymName(); + strings.opCppName = opToCppName(op); + + if (operandOp) { + strings.opOperandNames = SmallVector( + llvm::map_range(operandOp->getNames(), [](Attribute attr) { + return llvm::formatv("{0}", cast(attr)); + })); + } + + if (resultOp) { + strings.opResultNames = SmallVector( + llvm::map_range(resultOp->getNames(), [](Attribute attr) { + return llvm::formatv("{0}", cast(attr)); + })); + } + + return strings; +} + +/// Fills a dictionary with values from TypeStrings +static void fillDict(irdl::detail::dictionary &dict, + const TypeStrings &strings) { + dict["TYPE_NAME"] = strings.typeName; + dict["TYPE_CPP_NAME"] = strings.typeCppName; +} + +/// Fills a dictionary with values from OpStrings +static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { + const auto operandCount = strings.opOperandNames.size(); + const auto resultCount = strings.opResultNames.size(); + + dict["OP_NAME"] = strings.opName; + dict["OP_CPP_NAME"] = strings.opCppName; + dict["OP_OPERAND_COUNT"] = std::to_string(strings.opOperandNames.size()); + dict["OP_RESULT_COUNT"] = std::to_string(strings.opResultNames.size()); + dict["OP_OPERAND_INITIALIZER_LIST"] = + operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}"; + dict["OP_RESULT_INITIALIZER_LIST"] = + resultCount ? joinNameList(strings.opResultNames) : "{\"\"}"; +} + +/// Fills a dictionary with values from DialectStrings +static void fillDict(irdl::detail::dictionary &dict, + const DialectStrings &strings) { + dict["DIALECT_NAME"] = strings.dialectName; + dict["DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName; + dict["DIALECT_CPP_NAME"] = strings.dialectCppName; + dict["DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName; + dict["NAMESPACE_OPEN"] = strings.namespaceOpen; + dict["NAMESPACE_CLOSE"] = strings.namespaceClose; + dict["NAMESPACE_PATH"] = strings.namespacePath; +} + +static LogicalResult generateTypedefList(irdl::DialectOp &dialect, + SmallVector &typeNames) { + auto typeOps = dialect.getOps(); + auto range = llvm::map_range(typeOps, typeToCppName); + typeNames = SmallVector(range); + return success(); +} + +static LogicalResult generateOpList(irdl::DialectOp &dialect, + SmallVector &opNames) { + auto operationOps = dialect.getOps(); + auto range = llvm::map_range(operationOps, opToCppName); + opNames = SmallVector(range); + return success(); +} + +} // namespace + +static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, + irdl::detail::dictionary &dict) { + static const auto typeDeclTemplate = irdl::detail::Template( +#include "Templates/TypeDecl.txt" + ); + + fillDict(dict, getStrings(type)); + typeDeclTemplate.render(output, dict); + + return success(); +} + +static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, + const OpStrings &opStrings) { + auto opGetters = std::string{}; + auto resGetters = std::string{}; + + for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) { + const auto op = + llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i], true); + opGetters += llvm::formatv("::mlir::Value get{0}() { return " + "getStructuredOperands({1}).front(); }\n ", + op, i); + } + for (size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) { + const auto op = + llvm::convertToCamelFromSnakeCase(opStrings.opResultNames[i], true); + resGetters += llvm::formatv( + R"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); } + )", + op, i); + } + + dict["OP_OPERAND_GETTER_DECLS"] = opGetters; + dict["OP_RESULT_GETTER_DECLS"] = resGetters; +} + +static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, + const OpStrings &opStrings) { + std::string buildDecls; + llvm::raw_string_ostream stream{buildDecls}; + + auto resultParams = + llvm::join(llvm::map_range(opStrings.opResultNames, + [](StringRef name) -> std::string { + return llvm::formatv( + "::mlir::Type {0}, ", + llvm::convertToCamelFromSnakeCase(name)); + }), + ""); + + auto operandParams = + llvm::join(llvm::map_range(opStrings.opOperandNames, + [](StringRef name) -> std::string { + return llvm::formatv( + "::mlir::Value {0}, ", + llvm::convertToCamelFromSnakeCase(name)); + }), + ""); + + stream << llvm::formatv( + R"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)", + resultParams, operandParams); + dict["OP_BUILD_DECLS"] = buildDecls; +} + +static LogicalResult generateOperationInclude(irdl::OperationOp op, + raw_ostream &output, + irdl::detail::dictionary &dict) { + static const auto perOpDeclTemplate = irdl::detail::Template( +#include "Templates/PerOperationDecl.txt" + ); + const auto opStrings = getStrings(op); + fillDict(dict, opStrings); + + generateOpGetterDeclarations(dict, opStrings); + generateOpBuilderDeclarations(dict, opStrings); + + perOpDeclTemplate.render(output, dict); + return success(); +} + +static LogicalResult generateInclude(irdl::DialectOp dialect, + raw_ostream &output, + DialectStrings &dialectStrings) { + static const auto dialectDeclTemplate = irdl::detail::Template( +#include "Templates/DialectDecl.txt" + ); + static const auto typeHeaderDeclTemplate = irdl::detail::Template( +#include "Templates/TypeHeaderDecl.txt" + ); + + irdl::detail::dictionary dict; + fillDict(dict, dialectStrings); + + dialectDeclTemplate.render(output, dict); + typeHeaderDeclTemplate.render(output, dict); + + auto typeOps = dialect.getOps(); + auto operationOps = dialect.getOps(); + + for (auto &&typeOp : typeOps) { + if (failed(generateTypeInclude(typeOp, output, dict))) + return failure(); + } + + SmallVector opNames; + if (failed(generateOpList(dialect, opNames))) + return failure(); + + auto classDeclarations = + llvm::join(llvm::map_range(opNames, + [](llvm::StringRef name) -> std::string { + return llvm::formatv("class {0};", name); + }), + "\n"); + const auto forwardDeclarations = llvm::formatv( + "{1}\n{0}\n{2}", std::move(classDeclarations), + dialectStrings.namespaceOpen, dialectStrings.namespaceClose); + + output << forwardDeclarations; + for (auto &&operationOp : operationOps) { + if (failed(generateOperationInclude(operationOp, output, dict))) + return failure(); + } + + return success(); +} + +static std::string generateOpDefinition(irdl::detail::dictionary &dict, + irdl::OperationOp op) { + static const auto perOpDefTemplate = mlir::irdl::detail::Template{ +#include "Templates/PerOperationDef.txt" + }; + + auto opStrings = getStrings(op); + fillDict(dict, opStrings); + + const auto operandCount = opStrings.opOperandNames.size(); + const auto operandNames = + operandCount ? joinNameList(opStrings.opOperandNames) : "{\"\"}"; + + const auto resultNames = joinNameList(opStrings.opResultNames); + + auto resultTypes = llvm::join( + llvm::map_range(opStrings.opResultNames, + [](StringRef attr) -> std::string { + return llvm::formatv("::mlir::Type {0}, ", attr); + }), + ""); + auto operandTypes = llvm::join( + llvm::map_range(opStrings.opOperandNames, + [](StringRef attr) -> std::string { + return llvm::formatv("::mlir::Value {0}, ", attr); + }), + ""); + auto operandAdder = + llvm::join(llvm::map_range(opStrings.opOperandNames, + [](StringRef attr) -> std::string { + return llvm::formatv( + " opState.addOperands({0});", attr); + }), + "\n"); + auto resultAdder = llvm::join( + llvm::map_range(opStrings.opResultNames, + [](StringRef attr) -> std::string { + return llvm::formatv(" opState.addTypes({0});", attr); + }), + "\n"); + + const auto buildDefinition = llvm::formatv( + R"( +void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{ +{3} +{4} +} +)", + opStrings.opCppName, std::move(resultTypes), std::move(operandTypes), + std::move(operandAdder), std::move(resultAdder)); + + dict["OP_BUILD_DEFS"] = buildDefinition; + + std::string str; + llvm::raw_string_ostream stream{str}; + perOpDefTemplate.render(stream, dict); + return str; +} + +static std::string +generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings) { + return llvm::formatv( + R"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) { +value = {1}::{0}::get(parser.getContext()); +return ::mlir::success(!!value); +}))", + name, dialectStrings.namespacePath); +} + +static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, + DialectStrings &dialectStrings) { + + static const auto typeHeaderDefTemplate = mlir::irdl::detail::Template{ +#include "Templates/TypeHeaderDef.txt" + }; + static const auto typeDefTemplate = mlir::irdl::detail::Template{ +#include "Templates/TypeDef.txt" + }; + static const auto dialectDefTemplate = mlir::irdl::detail::Template{ +#include "Templates/DialectDef.txt" + }; + + irdl::detail::dictionary dict; + fillDict(dict, dialectStrings); + + typeHeaderDefTemplate.render(output, dict); + + SmallVector typeNames; + if (failed(generateTypedefList(dialect, typeNames))) + return failure(); + + dict["TYPE_LIST"] = llvm::join( + llvm::map_range(typeNames, + [&dialectStrings](llvm::StringRef name) -> std::string { + return llvm::formatv( + "{0}::{1}", dialectStrings.namespacePath, name); + }), + ",\n"); + + auto typeVerifierGenerator = + [&dialectStrings](llvm::StringRef name) -> std::string { + return generateTypeVerifierCase(name, dialectStrings); + }; + + auto typeCase = + llvm::join(llvm::map_range(typeNames, typeVerifierGenerator), "\n"); + + dict["TYPE_PARSER"] = llvm::formatv( + R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { + return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) + {0} + .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{ + *mnemonic = keyword; + return std::nullopt; + }); +})", + std::move(typeCase)); + + auto typePrintCase = + llvm::join(llvm::map_range(typeNames, + [&](llvm::StringRef name) -> std::string { + return llvm::formatv( + R"(.Case<{1}::{0}>([&](auto t) { + printer << {1}::{0}::getMnemonic(); + return ::mlir::success(); + }))", + name, dialectStrings.namespacePath); + }), + "\n"); + dict["TYPE_PRINTER"] = llvm::formatv( + R"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) { + return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def) + {0} + .Default([](auto) {{ return ::mlir::failure(); }); +})", + std::move(typePrintCase)); + + dict["TYPE_DEFINES"] = + join(map_range(typeNames, + [&](StringRef name) -> std::string { + return formatv("MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})", + name, dialectStrings.namespacePath); + }), + "\n"); + + typeDefTemplate.render(output, dict); + + auto operations = dialect.getOps(); + SmallVector opNames; + if (failed(generateOpList(dialect, opNames))) + return failure(); + + const auto commaSeparatedOpList = llvm::join( + map_range(opNames, + [&dialectStrings](llvm::StringRef name) -> std::string { + return llvm::formatv("{0}::{1}", dialectStrings.namespacePath, + name); + }), + ",\n"); + + const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) { + return generateOpDefinition(dict, op); + }; + + const auto perOpDefinitions = + llvm::join(llvm::map_range(operations, opDefinitionGenerator), "\n"); + + dict["OP_LIST"] = commaSeparatedOpList; + dict["OP_CLASSES"] = perOpDefinitions; + output << perOpDefinitions; + dialectDefTemplate.render(output, dict); + + return success(); +} + +static LogicalResult verifySupported(irdl::DialectOp dialect) { + LogicalResult res = success(); + dialect.walk([&](mlir::Operation *op) { + res = + llvm::TypeSwitch(op) + .Case(([](irdl::DialectOp) { return success(); })) + .Case( + ([](irdl::OperationOp) { return success(); })) + .Case(([](irdl::TypeOp) { return success(); })) + .Case(([](irdl::OperandsOp op) -> LogicalResult { + if (llvm::all_of( + op.getVariadicity(), [](irdl::VariadicityAttr attr) { + return attr.getValue() == irdl::Variadicity::single; + })) + return success(); + return op.emitError("IRDL C++ translation does not yet support " + "variadic operations"); + })) + .Case(([](irdl::ResultsOp op) -> LogicalResult { + if (llvm::all_of( + op.getVariadicity(), [](irdl::VariadicityAttr attr) { + return attr.getValue() == irdl::Variadicity::single; + })) + return success(); + return op.emitError( + "IRDL C++ translation does not yet support variadic results"); + })) + .Case(([](irdl::AnyOp) { return success(); })) + .Default([](mlir::Operation *op) -> LogicalResult { + return op->emitError("IRDL C++ translation does not yet support " + "translation of ") + << op->getName() << " operation"; + }); + + if (failed(res)) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + return res; +} + +LogicalResult +irdl::translateIRDLDialectToCpp(llvm::ArrayRef dialects, + raw_ostream &output) { + static const auto typeDefTempl = detail::Template( +#include "Templates/TypeDef.txt" + ); + + llvm::SmallMapVector dialectStringTable; + + for (auto dialect : dialects) { + if (failed(verifySupported(dialect))) + return failure(); + + StringRef dialectName = dialect.getSymName(); + + SmallVector> namespaceAbsolutePath{{"mlir"}, dialectName}; + std::string namespaceOpen; + std::string namespaceClose; + std::string namespacePath; + llvm::raw_string_ostream namespaceOpenStream(namespaceOpen); + llvm::raw_string_ostream namespaceCloseStream(namespaceClose); + llvm::raw_string_ostream namespacePathStream(namespacePath); + for (auto &pathElement : namespaceAbsolutePath) { + namespaceOpenStream << "namespace " << pathElement << " {\n"; + namespaceCloseStream << "} // namespace " << pathElement << "\n"; + namespacePathStream << "::" << pathElement; + } + + std::string cppShortName = + llvm::convertToCamelFromSnakeCase(dialectName, true); + std::string dialectBaseTypeName = llvm::formatv("{0}Type", cppShortName); + std::string cppName = llvm::formatv("{0}Dialect", cppShortName); + + DialectStrings dialectStrings; + dialectStrings.dialectName = dialectName; + dialectStrings.dialectBaseTypeName = dialectBaseTypeName; + dialectStrings.dialectCppName = cppName; + dialectStrings.dialectCppShortName = cppShortName; + dialectStrings.namespaceOpen = namespaceOpen; + dialectStrings.namespaceClose = namespaceClose; + dialectStrings.namespacePath = namespacePath; + + dialectStringTable[dialect] = std::move(dialectStrings); + } + + // generate the actual header + output << headerTemplateText; + + output << llvm::formatv("#ifdef {0}\n#undef {0}\n", declarationMacroFlag); + for (auto dialect : dialects) { + + auto &dialectStrings = dialectStringTable[dialect]; + auto &dialectName = dialectStrings.dialectName; + + if (failed(generateInclude(dialect, output, dialectStrings))) + return dialect->emitError("Error in Dialect " + dialectName + + " while generating headers"); + } + output << llvm::formatv("#endif // #ifdef {}\n", declarationMacroFlag); + + output << llvm::formatv("#ifdef {0}\n#undef {0}\n ", definitionMacroFlag); + for (auto &dialect : dialects) { + auto &dialectStrings = dialectStringTable[dialect]; + auto &dialectName = dialectStrings.dialectName; + + if (failed(generateLib(dialect, output, dialectStrings))) + return dialect->emitError("Error in Dialect " + dialectName + + " while generating library"); + } + output << llvm::formatv("#endif // #ifdef {}\n", definitionMacroFlag); + + return success(); +} diff --git a/mlir/lib/Target/IRDLToCpp/Templates/DialectDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/DialectDecl.txt new file mode 100644 index 0000000000000..e0f85feadd196 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/DialectDecl.txt @@ -0,0 +1,28 @@ +R"( +__NAMESPACE_OPEN__ + +class __DIALECT_CPP_NAME__ : public ::mlir::Dialect { + explicit __DIALECT_CPP_NAME__(::mlir::MLIRContext *context); + + void initialize(); + friend class ::mlir::MLIRContext; + +public: + ~__DIALECT_CPP_NAME__() override; + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("__DIALECT_NAME__"); + } + + /// Parse a type registered to this dialect. + ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; + + /// Print a type registered to this dialect. + void printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &os) const override; +}; + +__NAMESPACE_CLOSE__ + +MLIR_DECLARE_EXPLICIT_TYPE_ID(__NAMESPACE_PATH__::__DIALECT_CPP_NAME__) + +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/DialectDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/DialectDef.txt new file mode 100644 index 0000000000000..cd52de11dd6b0 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/DialectDef.txt @@ -0,0 +1,27 @@ +R"( + +__NAMESPACE_OPEN__ + +__DIALECT_CPP_NAME__::__DIALECT_CPP_NAME__(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, + ::mlir::TypeID::get<__DIALECT_CPP_NAME__>()) +{ + initialize(); +} + +__DIALECT_CPP_NAME__::~__DIALECT_CPP_NAME__() = default; + + +void __DIALECT_CPP_NAME__::initialize() { + addOperations< + __OP_LIST__ + >(); + addTypes< + __TYPE_LIST__ + >(); +} + +__NAMESPACE_CLOSE__ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(__NAMESPACE_PATH__::__DIALECT_CPP_NAME__) +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/Header.txt b/mlir/lib/Target/IRDLToCpp/Templates/Header.txt new file mode 100644 index 0000000000000..c081d6431778a --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/Header.txt @@ -0,0 +1,10 @@ +R"( +/*===- IRDL to C++ Generated file -------------------------------*- C++ -*-===*\ +|* *| +|* Dialect Declarations *| +|* *| +|* Automatically generated file, do not edit! *| +|* *| +\*===----------------------------------------------------------------------===*/ + +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt new file mode 100644 index 0000000000000..9e787b8056a89 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt @@ -0,0 +1,185 @@ +R"( +//===----------------------------------------------------------------------===// +// __NAMESPACE_PATH__::__OP_CPP_NAME__ declarations +//===----------------------------------------------------------------------===// + +__NAMESPACE_OPEN__ + +namespace detail { + +class __OP_CPP_NAME__GenericAdaptorBase { +public: + struct Properties { + }; +public: + __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) + : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), + odsRegions(op->getRegions()) + {} + + /// Return the unstructured operand index of a structured operand along with + // the amount of unstructured operands it contains. + std::pair + getStructuredOperandIndexAndLength (unsigned index, + unsigned odsOperandsSize) { + return {index, 1}; + } + + const Properties &getProperties() { + return properties; + } + + ::mlir::DictionaryAttr getAttributes() { + return odsAttrs; + } +protected: + ::mlir::DictionaryAttr odsAttrs; + ::std::optional<::mlir::OperationName> odsOpName; + Properties properties; + ::mlir::RegionRange odsRegions; +}; + +} // namespace detail + +template +class __OP_CPP_NAME__GenericAdaptor + : public detail::__OP_CPP_NAME__GenericAdaptorBase { + using ValueT = ::llvm::detail::ValueOfRange; + using Base = detail::__OP_CPP_NAME__GenericAdaptorBase; +public: + __OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, + ::mlir::OpaqueProperties properties, + ::mlir::RegionRange regions = {}) + : __OP_CPP_NAME__GenericAdaptor(values, attrs, + (properties ? *properties.as<::mlir::EmptyProperties *>() + : ::mlir::EmptyProperties{}), regions) {} + + __OP_CPP_NAME__GenericAdaptor(RangeT values, + const __OP_CPP_NAME__GenericAdaptorBase &base) + : Base(base), odsOperands(values) {} + + // This template parameter allows using __OP_CPP_NAME__ which is declared + // later. + template >> + __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) + : Base(op), odsOperands(values) {} + + /// Return the unstructured operand index of a structured operand along with + // the amount of unstructured operands it contains. + std::pair + getStructuredOperandIndexAndLength(unsigned index) { + return Base::getStructuredOperandIndexAndLength(index, odsOperands.size()); + } + + /// Get the n-th structured operand (single value, variadic or optional). + RangeT getStructuredOperands(unsigned index) { + auto valueRange = getStructuredOperandIndexAndLength(index); + return {std::next(odsOperands.begin(), valueRange.first), + std::next(odsOperands.begin(), + valueRange.first + valueRange.second)}; + } + + RangeT getOperands() { + return odsOperands; + } + + __OP_OPERAND_GETTER_DECLS__ + +private: + RangeT odsOperands; +}; + +class __OP_CPP_NAME__Adaptor + : public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> { +public: + using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor; + __OP_CPP_NAME__Adaptor(__OP_CPP_NAME__ op); + + ::llvm::LogicalResult verify(::mlir::Location loc); +}; + +class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> { +public: + using Op::Op; + using Op::print; + using Adaptor = __OP_CPP_NAME__Adaptor; + template + using GenericAdaptor = __OP_CPP_NAME__GenericAdaptor; + using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; + static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { + return {}; + } + + static constexpr ::llvm::StringLiteral getOperationName() { + return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__"); + } + + static ::llvm::ArrayRef<::llvm::StringRef> getOperandNames() { + static ::llvm::StringRef operandNames[] = __OP_OPERAND_INITIALIZER_LIST__; + return operandNames; + } + + static ::llvm::StringRef getOperandName(unsigned index) { + assert(index < __OP_OPERAND_COUNT__ && "invalid attribute index"); + return getOperandNames()[index]; + } + + static ::llvm::ArrayRef<::llvm::StringRef> getResultNames() { + static ::llvm::StringRef resultNames[] = __OP_RESULT_INITIALIZER_LIST__; + return resultNames; + } + + static ::llvm::StringRef getResultName(unsigned index) { + assert(index < __OP_RESULT_COUNT__ && "invalid attribute index"); + return getResultNames()[index]; + } + + /// Return the unstructured operand index of a structured operand along with + // the amount of unstructured operands it contains. + std::pair + getStructuredOperandIndexAndLength(unsigned index) { + return {index, 1}; + } + + /// Get the n-th structured operand (single value, variadic or optional). + ::mlir::Operation::operand_range getStructuredOperands(unsigned index) { + auto valueRange = getStructuredOperandIndexAndLength(index); + return {std::next(getOperation()->operand_begin(), valueRange.first), + std::next(getOperation()->operand_begin(), + valueRange.first + valueRange.second)}; + } + + /// Return the unstructured result index of a structured result along with + // the amount of unstructured results it contains. + std::pair + getStructuredResultIndexAndLength(unsigned index) { + return {index, 1}; + } + + /// Get the n-th structured result (single value, variadic or optional). + ::mlir::Operation::result_range getStructuredResults(unsigned index) { + auto valueRange = getStructuredResultIndexAndLength(index); + return {std::next(getOperation()->result_begin(), valueRange.first), + std::next(getOperation()->result_begin(), + valueRange.first + valueRange.second)}; + } + + __OP_OPERAND_GETTER_DECLS__ + __OP_RESULT_GETTER_DECLS__ + + __OP_BUILD_DECLS__ + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, + ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +}; + + +__NAMESPACE_CLOSE__ + +MLIR_DECLARE_EXPLICIT_TYPE_ID(__NAMESPACE_PATH__::__OP_CPP_NAME__) + +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt new file mode 100644 index 0000000000000..88f19d0f22ca4 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt @@ -0,0 +1,28 @@ +R"( + +//===----------------------------------------------------------------------===// +// __NAMESPACE_PATH__::__OP_CPP_NAME__ definitions +//===----------------------------------------------------------------------===// + +__NAMESPACE_OPEN__ + +__OP_BUILD_DEFS__ + +void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, + ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) +{ + assert(operands.size() == __OP_OPERAND_COUNT__); + assert(resultTypes.size() == __OP_RESULT_COUNT__); + odsState.addOperands(operands); + odsState.addAttributes(attributes); + odsState.addTypes(resultTypes); +} + + +__NAMESPACE_CLOSE__ + +MLIR_DEFINE_EXPLICIT_TYPE_ID(__NAMESPACE_PATH__::__OP_CPP_NAME__) +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/TypeDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/TypeDecl.txt new file mode 100644 index 0000000000000..201a550a29ef0 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/TypeDecl.txt @@ -0,0 +1,26 @@ +R"( + +//===----------------------------------------------------------------------===// +// __NAMESPACE_PATH__::__TYPE_CPP_NAME__ declarations +//===----------------------------------------------------------------------===// + +__NAMESPACE_OPEN__ + +class __TYPE_CPP_NAME__ + : public ::mlir::Type::TypeBase<__TYPE_CPP_NAME__, + __DIALECT_BASE_TYPE_NAME__, + ::mlir::TypeStorage> { +public: + using Base::Base; + static constexpr ::llvm::StringLiteral name = + "__DIALECT_NAME__.__TYPE_NAME__"; + static constexpr ::llvm::StringLiteral dialectName = "__DIALECT_NAME__"; + static constexpr ::llvm::StringLiteral getMnemonic() { + return {"__TYPE_NAME__"}; + } +}; + +__NAMESPACE_CLOSE__ + +MLIR_DECLARE_EXPLICIT_TYPE_ID(__NAMESPACE_PATH__::__TYPE_CPP_NAME__) +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/TypeDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/TypeDef.txt new file mode 100644 index 0000000000000..aaca4430507fb --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/TypeDef.txt @@ -0,0 +1,35 @@ +R"( + +__TYPE_DEFINES__ + +__NAMESPACE_OPEN__ + +__TYPE_PARSER__ + +__TYPE_PRINTER__ + +/// Parse a type registered to this dialect. +::mlir::Type __DIALECT_CPP_NAME__::parseType(::mlir::DialectAsmParser &parser) + const +{ + ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); + ::llvm::StringRef mnemonic; + ::mlir::Type genType; + auto parseResult = generatedTypeParser(parser, &mnemonic, genType); + if (parseResult.has_value()) + return genType; + + parser.emitError(typeLoc) << "unknown type `" + << mnemonic << "` in dialect `" << getNamespace() << "`"; + return {}; +} + +/// Print a type registered to this dialect. +void __DIALECT_CPP_NAME__::printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &printer) const { + ((void) generatedTypePrinter(type, printer)); + +} + +__NAMESPACE_CLOSE__ +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/TypeHeaderDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/TypeHeaderDecl.txt new file mode 100644 index 0000000000000..878bbdb7d3bd2 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/TypeHeaderDecl.txt @@ -0,0 +1,14 @@ +R"( +__NAMESPACE_OPEN__ + +class __DIALECT_BASE_TYPE_NAME__ : public ::mlir::Type { +public: + using Type::Type; + + static bool classof(Type type); +}; + +__NAMESPACE_CLOSE__ + +MLIR_DECLARE_EXPLICIT_TYPE_ID(__NAMESPACE_PATH__::__DIALECT_BASE_TYPE_NAME__) +)" diff --git a/mlir/lib/Target/IRDLToCpp/Templates/TypeHeaderDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/TypeHeaderDef.txt new file mode 100644 index 0000000000000..be5ba2f390fce --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/Templates/TypeHeaderDef.txt @@ -0,0 +1,9 @@ +R"( +__NAMESPACE_OPEN__ + +bool __DIALECT_BASE_TYPE_NAME__::classof(Type type) { + return llvm::isa<__DIALECT_CPP_NAME__>(type.getDialect()); +} + +__NAMESPACE_CLOSE__ +)" diff --git a/mlir/lib/Target/IRDLToCpp/TemplatingUtils.h b/mlir/lib/Target/IRDLToCpp/TemplatingUtils.h new file mode 100644 index 0000000000000..aeb3c09543213 --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/TemplatingUtils.h @@ -0,0 +1,89 @@ +//===- TemplatingUtils.h - Templater for text templates -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_TARGET_IRDLTOCPP_TEMPLATINGUTILS_H +#define MLIR_LIB_TARGET_IRDLTOCPP_TEMPLATINGUTILS_H + +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir::irdl::detail { + +/// A dictionary stores a mapping of template variable names to their assigned +/// string values. +using dictionary = llvm::StringMap>; + +/// Template Code as used by IRDL-to-Cpp. +/// +/// For efficiency, produces a bytecode representation of an input template. +/// - LiteralToken: A contiguous stream of characters to be printed +/// - ReplacementToken: A template variable that will be replaced +class Template { +public: + Template(llvm::StringRef str) { + bool processingReplacementToken = false; + while (!str.empty()) { + auto [token, remainder] = str.split("__"); + + if (processingReplacementToken) { + assert(!token.empty() && "replacement name cannot be empty"); + bytecode.emplace_back(ReplacementToken{token}); + } else { + if (!token.empty()) + bytecode.emplace_back(LiteralToken{token}); + } + + processingReplacementToken = !processingReplacementToken; + str = remainder; + } + } + + /// Render will apply a dictionary to the Template and send the rendered + /// result to the specified output stream. + void render(llvm::raw_ostream &out, const dictionary &replacements) const { + for (auto instruction : bytecode) { + if (auto *inst = std::get_if(&instruction)) { + out << inst->text; + continue; + } + + if (auto *inst = std::get_if(&instruction)) { + auto replacement = replacements.find(inst->keyName); +#ifndef NDEBUG + if (replacement == replacements.end()) { + llvm::errs() << "Missing template key: " << inst->keyName << "\n"; + llvm_unreachable("Missing template key"); + } +#endif + out << replacement->second; + continue; + } + + llvm_unreachable("non-exhaustive bytecode visit"); + } + } + +private: + struct LiteralToken { + llvm::StringRef text; + }; + + struct ReplacementToken { + llvm::StringRef keyName; + }; + + std::vector> bytecode; +}; + +} // namespace mlir::irdl::detail + +#endif // MLIR_LIB_TARGET_IRDLTOCPP_TEMPLATINGUTILS_H diff --git a/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp b/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp new file mode 100644 index 0000000000000..2a991662738ba --- /dev/null +++ b/mlir/lib/Target/IRDLToCpp/TranslationRegistration.cpp @@ -0,0 +1,49 @@ +//===- TranslationRegistration.cpp - Register translation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/IRDLToCpp/TranslationRegistration.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Target/IRDLToCpp/IRDLToCpp.h" +#include "mlir/Tools/mlir-translate/Translation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" + +using namespace mlir; + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Translation registration +//===----------------------------------------------------------------------===// + +void registerIRDLToCppTranslation() { + TranslateFromMLIRRegistration reg( + "irdl-to-cpp", "translate IRDL dialect definitions to C++ definitions", + [](Operation *op, raw_ostream &output) { + return TypeSwitch(op) + .Case([&](irdl::DialectOp dialectOp) { + return irdl::translateIRDLDialectToCpp(dialectOp, output); + }) + .Case([&](ModuleOp moduleOp) { + for (Operation &op : moduleOp.getBody()->getOperations()) + if (auto dialectOp = llvm::dyn_cast(op)) + if (failed( + irdl::translateIRDLDialectToCpp(dialectOp, output))) + return failure(); + return success(); + }) + .Default([](Operation *op) { + return op->emitError( + "unsupported operation for IRDL to C++ translation"); + }); + }, + [](DialectRegistry ®istry) { registry.insert(); }); +} + +} // namespace mlir diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir index 1e66161b800df..8a7fffe1a9cbd 100644 --- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir +++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir @@ -25,7 +25,7 @@ irdl.dialect @testd { irdl.dialect @testd { irdl.type @type { %0 = irdl.any - // expected-error@+1 {{name of parameter #0 must contain only letters, digits and underscores}} + // expected-error@+1 {{name of parameter #0 must contain only lowercase letters, digits and underscores}} irdl.parameters(test$test: %0) } } @@ -35,7 +35,7 @@ irdl.dialect @testd { irdl.dialect @testd { irdl.operation @op { %0 = irdl.any - // expected-error@+1 {{name of result #0 must contain only letters, digits and underscores}} + // expected-error@+1 {{name of result #0 must contain only lowercase letters, digits and underscores}} irdl.results(test$test: %0) } } @@ -45,7 +45,7 @@ irdl.dialect @testd { irdl.dialect @testd { irdl.operation @op { %0 = irdl.any - // expected-error@+1 {{name of operand #0 must contain only letters, digits and underscores}} + // expected-error@+1 {{name of operand #0 must contain only lowercase letters, digits and underscores}} irdl.operands(test$test: %0) } } diff --git a/mlir/test/Dialect/IRDL/invalid_names.irdl.mlir b/mlir/test/Dialect/IRDL/invalid_names.irdl.mlir new file mode 100644 index 0000000000000..e2d2be52fc3db --- /dev/null +++ b/mlir/test/Dialect/IRDL/invalid_names.irdl.mlir @@ -0,0 +1,94 @@ +// RUN: mlir-irdl-to-cpp %s --verify-diagnostics --split-input-file +// expected-error@+1 {{name of dialect should not contain leading or double underscores}} +irdl.dialect @_no_leading_underscore { +} + +// ----- + +// expected-error@+1 {{name of dialect should not contain leading or double underscores}} +irdl.dialect @no__double__underscores { +} + +// ----- + +// expected-error@+1 {{name of dialect should not contain uppercase letters}} +irdl.dialect @NoUpperCase { +} + +// ----- + +// expected-error@+1 {{name of dialect must contain only lowercase letters, digits and underscores}} +irdl.dialect @no_weird_symbol$ { +} + +// ----- + +irdl.dialect @test_dialect { + // expected-error@+1 {{name of operation should not contain leading or double underscores}} + irdl.operation @_no_leading_underscore { + %0 = irdl.any + irdl.results(res: %0) + } +} + +// ----- + +irdl.dialect @test_dialect { + // expected-error@+1 {{name of operation should not contain leading or double underscores}} + irdl.operation @no__double__underscores { + %0 = irdl.any + irdl.results(res: %0) + } +} + +// ----- + +irdl.dialect @test_dialect { + // expected-error@+1 {{name of operation should not contain uppercase letters}} + irdl.operation @NoUpperCase { + %0 = irdl.any + irdl.results(res: %0) + } +} + +// ----- + +irdl.dialect @test_dialect { + // expected-error@+1 {{name of operation must contain only lowercase letters, digits and underscores}} + irdl.operation @no_weird_symbol$ { + %0 = irdl.any + irdl.results(res: %0) + } +} + +// ----- + +irdl.dialect @test_dialect { + irdl.operation @test_op { + %0 = irdl.any + // expected-error@+1 {{name of result #0 should not contain leading or double underscores}} + irdl.results(_no_leading_underscore: %0) + } +} + +// ----- + +irdl.dialect @test_dialect { + irdl.operation @test_op { + %0 = irdl.any + // expected-error@+1 {{name of result #0 should not contain leading or double underscores}} + irdl.results(no__double__underscores: %0) + } +} + +// ----- + +irdl.dialect @test_dialect { + irdl.operation @test_op { + %0 = irdl.any + // expected-error@+1 {{name of result #0 should not contain uppercase letters}} + irdl.results(NoUpperCase: %0) + } +} + +// ----- diff --git a/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir b/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir index 7c2efc6c96e41..e5f884c99e5f4 100644 --- a/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir +++ b/mlir/test/Dialect/IRDL/regions-ops.irdl.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -verify-diagnostics -split-input-file -irdl.dialect @testRegionOpNegativeNumber { +irdl.dialect @test_region_op_negative_number { irdl.operation @op { // expected-error @below {{'irdl.region' op the number of blocks is expected to be >= 1 but got -42}} %r1 = irdl.region with size -42 @@ -9,7 +9,7 @@ irdl.dialect @testRegionOpNegativeNumber { // ----- -irdl.dialect @testRegionsOpMissingName { +irdl.dialect @test_regions_op_missing_name { irdl.operation @op { %r1 = irdl.region // expected-error @below {{expected valid keyword}} @@ -19,17 +19,17 @@ irdl.dialect @testRegionsOpMissingName { // ----- -irdl.dialect @testRegionsOpWrongName { +irdl.dialect @test_regions_op_wrong_name { irdl.operation @op { %r1 = irdl.region - // expected-error @below {{name of region #0 must contain only letters, digits and underscores}} + // expected-error @below {{name of region #0 must contain only lowercase letters, digits and underscores}} irdl.regions(test$test: %r1) } } // ----- -irdl.dialect @testRegionsDuplicateName { +irdl.dialect @test_regions_duplicate_name { irdl.operation @op { %r1 = irdl.region // expected-error @below {{name of region #2 is a duplicate of the name of region #0}} @@ -39,7 +39,7 @@ irdl.dialect @testRegionsDuplicateName { // ----- -irdl.dialect @testRegionsOpWrongOperation { +irdl.dialect @test_regions_op_wrong_operation { irdl.operation @op { // expected-note @below {{prior use here}} %r1 = irdl.any diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index a8fd70e6397a5..2a3b2dce360b0 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Test) add_subdirectory(TestDyn) +add_subdirectory(TestIRDLToCpp) add_subdirectory(Tosa) add_subdirectory(Transform) add_subdirectory(Vector) diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt new file mode 100644 index 0000000000000..98a4bd3058a2a --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt @@ -0,0 +1,15 @@ +add_irdl_to_cpp_target(TestIRDLToCppGen test_irdl_to_cpp.irdl.mlir) + +add_mlir_library(MLIRTestIRDLToCppDialect + TestIRDLToCppDialect.cpp + + EXCLUDE_FROM_LIBMLIR + + DEPENDS + TestIRDLToCppGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms +) diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp new file mode 100644 index 0000000000000..bda614a97ab42 --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp @@ -0,0 +1,88 @@ +//===- TestIRDLToCppDialect.cpp - MLIR Test Dialect Types ---------------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file includes TestIRDLToCpp dialect. +// +//===----------------------------------------------------------------------===// + +// #include "mlir/IR/Dialect.h" +#include "mlir/IR/Region.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Tools/mlir-translate/Translation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "TestIRDLToCppDialect.h" + +#define GEN_DIALECT_DEF +#include "test_irdl_to_cpp.irdl.mlir.cpp.inc" + +namespace test { +using namespace mlir; +struct TestOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::test_irdl_to_cpp::BeefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getStructuredOperands(0).size() == 1); + assert(adaptor.getStructuredOperands(1).size() == 1); + + auto bar = rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front()); + rewriter.setInsertionPointAfter(bar); + + rewriter.create( + bar.getLoc(), rewriter.getIntegerType(32), adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ConvertTestDialectToSomethingPass + : PassWrapper> { + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + ConversionTarget target(getContext()); + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } + + StringRef getArgument() const final { return "test-irdl-conversion-check"; } + StringRef getDescription() const final { + return "Checks the convertability of an irdl dialect"; + } +}; + +void registerIrdlTestDialect(mlir::DialectRegistry ®istry) { + registry.insert(); +} + +} // namespace test + +namespace mlir::test { +void registerTestIrdlTestDialectConversionPass() { + PassRegistration<::test::ConvertTestDialectToSomethingPass>(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.h b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.h new file mode 100644 index 0000000000000..5d2e3d7b4cd48 --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.h @@ -0,0 +1,23 @@ +//===- TestIRDLToCppDialect.h - MLIR Test Dialect Types -----------------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file includes TestIRDLToCpp dialect headers. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_LIB_DIALECT_TESTIRDLTOCPP_TESTIRDLTOCPPDIALECT_H +#define MLIR_TEST_LIB_DIALECT_TESTIRDLTOCPP_TESTIRDLTOCPPDIALECT_H + +#define GEN_DIALECT_DECL_HEADER +#include "test_irdl_to_cpp.irdl.mlir.cpp.inc" + +namespace test { +void registerConvertTestDialectPass(); +} + +#endif // MLIR_TEST_LIB_DIALECT_TESTIRDLTOCPP_TESTIRDLTOCPPDIALECT_H diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test.testd.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test.testd.mlir new file mode 100644 index 0000000000000..30f27934857a5 --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test.testd.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-opt %s | FileCheck %s +// CHECK: module { +// CHECK-NEXT: [[v1:[^ ]+]] = "test_irdl_to_cpp.bar"() : () -> !test_irdl_to_cpp.foo +module { + %0 = "test_irdl_to_cpp.bar"() : () -> !test_irdl_to_cpp.foo +} diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir new file mode 100644 index 0000000000000..f6233ee18190a --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-irdl-conversion-check)" | FileCheck %s +// CHECK-LABEL: module { +module { + // CHECK: func.func @test() { + // CHECK: %[[v0:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32 + // CHECK: %[[v1:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32 + // CHECK: %[[v2:[^ ]*]] = "test_irdl_to_cpp.hash"(%[[v0]], %[[v0]]) : (i32, i32) -> i32 + // CHECK: return + // CHECK: } + func.func @test() { + %0 = "test_irdl_to_cpp.bar"() : () -> i32 + %1 = "test_irdl_to_cpp.beef"(%0, %0) : (i32, i32) -> i32 + return + } + +// CHECK: } +} diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir new file mode 100644 index 0000000000000..42e713e0adecd --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-irdl-to-cpp %s | FileCheck %s + +// CHECK: class TestIrdlToCpp +irdl.dialect @test_irdl_to_cpp { + + // CHECK: class FooType + irdl.type @foo + + // CHECK: class BarOp + // CHECK: ::mlir::Value getRes() + irdl.operation @bar { + %0 = irdl.any + irdl.results(res: %0) + } + + // CHECK: class BeefOp + // CHECK: ::mlir::Value getLhs() + // CHECK: ::mlir::Value getRhs() + // CHECK: ::mlir::Value getRes() + irdl.operation @beef { + %0 = irdl.any + irdl.operands(lhs: %0, rhs: %0) + irdl.results(res: %0) + } + + // CHECK: class HashOp + // CHECK: ::mlir::Value getLhs() + // CHECK: ::mlir::Value getRhs() + // CHECK: ::mlir::Value getRes() + irdl.operation @hash { + %0 = irdl.any + irdl.operands(lhs: %0, rhs: %0) + irdl.results(res: %0) + } +} diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir new file mode 100644 index 0000000000000..403b49235467c --- /dev/null +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-irdl-to-cpp %s --verify-diagnostics --split-input-file +irdl.dialect @test_irdl_to_cpp { + irdl.operation @results_no_any_of { + %0 = irdl.any + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.any_of operation}} + %1 = irdl.any_of(%0, %0) + irdl.results(res: %1) + } +} +// ----- + +irdl.dialect @test_irdl_to_cpp { + irdl.operation @operands_no_any_of { + %0 = irdl.any + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.all_of operation}} + %1 = irdl.all_of(%0, %0) + irdl.operands(test: %1) + irdl.results(res: %0) + } +} + +// ----- + +irdl.dialect @test_irdl_to_cpp { + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.attribute operation}} + irdl.attribute @no_attrs +} + +// ----- + +irdl.dialect @test_irdl_to_cpp { + irdl.operation @test_op { + %0 = irdl.any + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.attributes operation}} + irdl.attributes { + "attr" = %0 + } + } +} + +// ----- + +irdl.dialect @test_irdl_to_cpp { + irdl.type @ty { + %0 = irdl.any + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.parameters operation}} + irdl.parameters(ty: %0) + } +} + +// ----- + +irdl.dialect @test_irdl_to_cpp { + irdl.operation @test_op { + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.region operation}} + %0 = irdl.region() + irdl.regions(reg: %0) + } + +} + +// ----- + +irdl.dialect @test_irdl_to_cpp { + irdl.operation @test_op { + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.regions operation}} + irdl.regions() + } + +} + +// ----- + +irdl.dialect @test_irdl_to_cpp { + irdl.type @test_derived { + // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.base operation}} + %0 = irdl.base "!builtin.integer" + } +} diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td index 913b9b1b759e2..4ab1750e30fb4 100644 --- a/mlir/test/tblgen-to-irdl/TestDialect.td +++ b/mlir/test/tblgen-to-irdl/TestDialect.td @@ -97,7 +97,7 @@ def Test_Integers : Test_Op<"integers"> { // Check that IRDL names are properly generated when needed. def Test_NamesOp : Test_Op<"names"> { let arguments = (ins I32:$a, - I32:$_hello, + I32:$hello, I32:$unnamed0, I32); let regions = (region AnyRegion:$unnamed1); @@ -110,7 +110,7 @@ def Test_NamesOp : Test_Op<"names"> { // CHECK-NEXT: %[[v3:[^ ]*]] = irdl.is i32 // CHECK-NEXT: %[[v4:[^ ]*]] = irdl.is i32 // CHECK-NEXT: %[[reg:[^ ]*]] = irdl.region -// CHECK-NEXT: irdl.operands(a: %[[v0]], _hello: %[[v1]], unnamed0: %[[v2]], unnamed3: %[[v3]]) +// CHECK-NEXT: irdl.operands(a: %[[v0]], hello: %[[v1]], unnamed0: %[[v2]], unnamed3: %[[v3]]) // CHECK-NEXT: irdl.results(unnamed2: %[[v4]]) // CHECK-NEXT: irdl.regions(unnamed1: %[[reg]]) // CHECK-NEXT: } diff --git a/mlir/tools/mlir-irdl-to-cpp/CMakeLists.txt b/mlir/tools/mlir-irdl-to-cpp/CMakeLists.txt new file mode 100644 index 0000000000000..3a0f77d856b44 --- /dev/null +++ b/mlir/tools/mlir-irdl-to-cpp/CMakeLists.txt @@ -0,0 +1,28 @@ +add_llvm_executable(mlir-irdl-to-cpp + mlir-irdl-to-cpp.cpp + ) +mlir_target_link_libraries(mlir-irdl-to-cpp + PRIVATE + MLIRTargetIRDLToCpp + ) + +# Set up a native build when cross-compiling. +if(LLVM_USE_HOST_TOOLS) + build_native_tool( + mlir-irdl-to-cpp + MLIR_IRDL_TO_CPP_EXE + + # Native tool must depend on target tool so that the native tool gets + # properly rebuilt when the target tool changes. + DEPENDS mlir-irdl-to-cpp + ) + add_custom_target(mlir-irdl-to-cpp-host DEPENDS ${MLIR_IRDL_TO_CPP_EXE}) + set(MLIR_IRDL_TO_CPP_TARGET mlir-irdl-to-cpp-host) +else() + set(MLIR_IRDL_TO_CPP_EXE $) + set(MLIR_IRDL_TO_CPP_TARGET mlir-irdl-to-cpp) +endif() + +# Save the executable path and target name to the cache to expose it globally. +set(MLIR_IRDL_TO_CPP_EXE "${MLIR_IRDL_TO_CPP_EXE}" CACHE INTERNAL "") +set(MLIR_IRDL_TO_CPP_TARGET "${MLIR_IRDL_TO_CPP_TARGET}" CACHE INTERNAL "") diff --git a/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp b/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp new file mode 100644 index 0000000000000..a63b289ffaea2 --- /dev/null +++ b/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp @@ -0,0 +1,143 @@ +//===- mlir-irdl-to-cpp.cpp - IRDL to C++ conversion tool -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a command line utility that translates an IRDL dialect definition +// into a C++ implementation to be included in MLIR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/ToolUtilities.h" +#include "mlir/Target/IRDLToCpp/IRDLToCpp.h" +#include "mlir/Tools/ParseUtilities.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; + +static LogicalResult +processBuffer(llvm::raw_ostream &os, + std::unique_ptr ownedBuffer, + bool verifyDiagnostics, llvm::ThreadPoolInterface *threadPool) { + // Tell sourceMgr about this buffer, which is what the parser will pick up. + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); + + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(registry); + + ctx.printOpOnDiagnostic(!verifyDiagnostics); + + auto runTranslation = [&]() { + ParserConfig parseConfig(&ctx); + OwningOpRef op = + parseSourceFileForTool(sourceMgr, parseConfig, true); + if (!op) + return failure(); + + auto moduleOp = llvm::cast(*op); + llvm::SmallVector dialects{ + moduleOp.getOps(), + }; + + return irdl::translateIRDLDialectToCpp(dialects, os); + }; + + if (!verifyDiagnostics) { + // If no errors are expected, return translation result. + SourceMgrDiagnosticHandler srcManagerHandler(*sourceMgr, &ctx); + return runTranslation(); + } + + // If errors are expected, ignore translation result and check for + // diagnostics. + SourceMgrDiagnosticVerifierHandler srcManagerHandler(*sourceMgr, &ctx); + (void)runTranslation(); + return srcManagerHandler.verify(); +} + +static LogicalResult translateIRDLToCpp(int argc, char **argv) { + static llvm::cl::opt inputFilename( + llvm::cl::Positional, llvm::cl::desc(""), + llvm::cl::init("-")); + + static llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + + static llvm::cl::opt verifyDiagnostics( + "verify-diagnostics", + llvm::cl::desc("Check that emitted diagnostics match " + "expected-* lines on the corresponding line"), + llvm::cl::init(false)); + + static llvm::cl::opt splitInputFile( + "split-input-file", llvm::cl::ValueOptional, + llvm::cl::callback([&](const std::string &str) { + // Implicit value: use default marker if flag was used without + // value. + if (str.empty()) + splitInputFile.setValue(kDefaultSplitMarker); + }), + llvm::cl::desc("Split the input file into chunks using the given or " + "default marker and process each chunk independently"), + llvm::cl::init("")); + + llvm::InitLLVM y(argc, argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, "mlir-irdl-to-cpp"); + + std::string errorMessage; + std::unique_ptr input = + openInputFile(inputFilename, &errorMessage); + if (!input) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + std::unique_ptr output = + openOutputFile(outputFilename, &errorMessage); + + if (!output) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + auto chunkFn = [&](std::unique_ptr chunkBuffer, + raw_ostream &os) { + return processBuffer(output->os(), std::move(chunkBuffer), + verifyDiagnostics, nullptr); + }; + + auto &splitInputFileDelimiter = splitInputFile.getValue(); + if (splitInputFileDelimiter.size()) + return splitAndProcessBuffer(std::move(input), chunkFn, output->os(), + splitInputFileDelimiter, + splitInputFileDelimiter); + + if (failed(chunkFn(std::move(input), output->os()))) + return failure(); + + if (!verifyDiagnostics) + output->keep(); + + return success(); +} + +int main(int argc, char **argv) { + return failed(translateIRDLToCpp(argc, argv)); +} diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 3220dca282eac..67fea55d141c0 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -53,6 +53,7 @@ if(MLIR_INCLUDE_TESTS) set(test_libs ${test_libs} MLIRTestPDLL MLIRTestTransformDialect + MLIRTestIRDLToCppDialect ) if (MLIR_ENABLE_PDL_IN_PATTERNMATCH) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index cdcf59b2add13..b51da0e0383d6 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -128,6 +128,7 @@ void registerTestMatchReductionPass(); void registerTestMathAlgebraicSimplificationPass(); void registerTestMathPolynomialApproximationPass(); void registerTestMathToVCIXPass(); +void registerTestIrdlTestDialectConversionPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); void registerTestMeshReshardingSpmdizationPass(); @@ -172,6 +173,7 @@ void registerTestDialect(DialectRegistry &); void registerTestDynDialect(DialectRegistry &); void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); +void registerIrdlTestDialect(DialectRegistry &); void registerTestTransformsTransformDialectExtension(DialectRegistry &); } // namespace test @@ -250,6 +252,7 @@ void registerTestPasses() { mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); + mlir::test::registerTestIrdlTestDialectConversionPass(); mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestLastModifiedPass(); mlir::test::registerTestLinalgDecomposeOps(); @@ -326,11 +329,12 @@ int main(int argc, char **argv) { registerAllGPUToLLVMIRTranslations(registry); #ifdef MLIR_INCLUDE_TESTS + ::test::registerIrdlTestDialect(registry); ::test::registerTestDialect(registry); + ::test::registerTestDynDialect(registry); + ::test::registerTestTilingInterfaceTransformDialectExtension(registry); ::test::registerTestTransformDialectExtension(registry); ::test::registerTestTransformsTransformDialectExtension(registry); - ::test::registerTestTilingInterfaceTransformDialectExtension(registry); - ::test::registerTestDynDialect(registry); #endif return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "MLIR modular optimizer driver\n", registry));