diff --git a/mlir/docs/Tools/mlir-rewrite.md b/mlir/docs/Tools/mlir-rewrite.md new file mode 100644 index 0000000000000..178f92f72cbb6 --- /dev/null +++ b/mlir/docs/Tools/mlir-rewrite.md @@ -0,0 +1,29 @@ +# mlir-rewrite + +Tool to simplify rewriting .mlir files. There are a couple of build in rewrites +discussed below along with usage. + +Note: This is still in very early stage. Its so early its less a tool than a +growing collection of useful functions: to use its best to do what's needed on +a brance by just hacking it (dialects registered, rewrites etc) to say help +ease a rename, upstream useful utility functions, point to ease others +migrating, and then bin eventually. Once there are actually useful parts it +should be refactored same as mlir-opt. + +[TOC] + +## simple-rename + +Rename per op given a substring to a target. The match and replace uses LLVM's +regex sub for the match and replace while the op-name is matched via regular +string comparison. E.g., + +``` +mlir-rewrite input.mlir -o output.mlir --simple-rename \ + --simple-rename-op-name="test.concat" --simple-rename-match="axis" \ + --simple-rename-replace="bxis" +``` + +to replace `axis` substring in the text of the range corresponding to +`test.concat` ops with `bxis`. + diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 4d2d738b734ec..361981605a76b 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -115,6 +115,7 @@ set(MLIR_TEST_DEPENDS mlir-opt mlir-query mlir-reduce + mlir-rewrite mlir-tblgen mlir-translate tblgen-lsp-server diff --git a/mlir/test/mlir-rewrite/simple.mlir b/mlir/test/mlir-rewrite/simple.mlir new file mode 100644 index 0000000000000..ab6bfe24fccf0 --- /dev/null +++ b/mlir/test/mlir-rewrite/simple.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s | mlir-rewrite --simple-rename --simple-rename-op-name="test.concat" --simple-rename-match="axis" --simple-rename-replace="bxis" | FileCheck %s -check-prefix=RENAME +// RUN: mlir-opt %s | mlir-rewrite --mark-ranges | FileCheck %s -check-prefix=RANGE +// Note: running through mlir-opt to just strip out comments & avoid self matches. + +func.func @two_dynamic_one_direct_shape(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor { + // RENAME: "test.concat"({{.*}}) {bxis = 0 : i64} + // RANGE: 《%{{.*}} = 〖"test.concat"〗({{.*}}) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor》 + %5 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor + return %5 : tensor +} + diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt index 9b474385fdae1..0a2d0ff291509 100644 --- a/mlir/tools/CMakeLists.txt +++ b/mlir/tools/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(mlir-parser-fuzzer) add_subdirectory(mlir-pdll-lsp-server) add_subdirectory(mlir-query) add_subdirectory(mlir-reduce) +add_subdirectory(mlir-rewrite) add_subdirectory(mlir-shlib) add_subdirectory(mlir-spirv-cpu-runner) add_subdirectory(mlir-translate) diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt new file mode 100644 index 0000000000000..5b8c1cd455399 --- /dev/null +++ b/mlir/tools/mlir-rewrite/CMakeLists.txt @@ -0,0 +1,35 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +set(LLVM_LINK_COMPONENTS + Support + ) + +set(LIBS + ${dialect_libs} + ${test_libs} + + MLIRAffineAnalysis + MLIRAnalysis + MLIRCastInterfaces + MLIRDialect + MLIRParser + MLIRPass + MLIRTransforms + MLIRTransformUtils + MLIRSupport + MLIRIR + ) + +include_directories(../../../clang/include) + +add_mlir_tool(mlir-rewrite + mlir-rewrite.cpp + + DEPENDS + ${LIBS} + SUPPORT_PLUGINS + ) +target_link_libraries(mlir-rewrite PRIVATE ${LIBS}) +llvm_update_compile_flags(mlir-rewrite) + +mlir_check_all_link_libraries(mlir-rewrite) +export_executable_symbols_for_plugins(mlir-rewrite) diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp new file mode 100644 index 0000000000000..308e6490726c8 --- /dev/null +++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp @@ -0,0 +1,392 @@ +//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Main entry function for mlir-rewrite. +// +//===----------------------------------------------------------------------===// + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Tools/ParseUtilities.h" +#include "llvm/ADT/RewriteBuffer.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/LineIterator.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; + +namespace mlir { +using OperationDefinition = AsmParserState::OperationDefinition; + +/// Return the source code associated with the OperationDefinition. +SMRange getOpRange(const OperationDefinition &op) { + const char *startOp = op.scopeLoc.Start.getPointer(); + const char *endOp = op.scopeLoc.End.getPointer(); + + for (auto res : op.resultGroups) { + SMRange range = res.definition.loc; + startOp = std::min(startOp, range.Start.getPointer()); + } + return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)}; +} + +/// Helper to simplify rewriting the source file. +class RewritePad { +public: + static std::unique_ptr init(StringRef inputFilename, + StringRef outputFilename); + + /// Return the context the file was parsed into. + MLIRContext *getContext() { return &context; } + + /// Return the OperationDefinition's of the operation's parsed. + iterator_range getOpDefs() { + return asmState.getOpDefs(); + } + + /// Insert the specified string at the specified location in the original + /// buffer. + void insertText(SMLoc pos, StringRef str, bool insertAfter = true) { + rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter); + } + + /// Replace the range of the source text with the corresponding string in the + /// output. + void replaceRange(SMRange range, StringRef str) { + rewriteBuffer.ReplaceText(range.Start.getPointer() - start, + range.End.getPointer() - range.Start.getPointer(), + str); + } + + /// Replace the range of the operation in the source text with the + /// corresponding string in the output. + void replaceDef(const OperationDefinition &opDef, StringRef newDef) { + replaceRange(getOpRange(opDef), newDef); + } + + /// Return the source string corresponding to the source range. + StringRef getSourceString(SMRange range) { + return StringRef(range.Start.getPointer(), + range.End.getPointer() - range.Start.getPointer()); + } + + /// Return the source string corresponding to operation definition. + StringRef getSourceString(const OperationDefinition &opDef) { + auto range = getOpRange(opDef); + return getSourceString(range); + } + + /// Write to stream the result of applying all changes to the + /// original buffer. + /// Note that it isn't safe to use this function to overwrite memory mapped + /// files in-place (PR17960). + /// + /// The original buffer is not actually changed. + raw_ostream &write(raw_ostream &stream) const { + return rewriteBuffer.write(stream); + } + + /// Return lines that are purely comments. + SmallVector getSingleLineComments() { + unsigned curBuf = sourceMgr.getMainFileID(); + const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf); + llvm::line_iterator lineIterator(*curMB); + SmallVector ret; + for (; !lineIterator.is_at_end(); ++lineIterator) { + StringRef trimmed = lineIterator->ltrim(); + if (trimmed.starts_with("//")) { + ret.emplace_back( + SMLoc::getFromPointer(trimmed.data()), + SMLoc::getFromPointer(trimmed.data() + trimmed.size())); + } + } + return ret; + } + + /// Return the IR from parsed file. + Block *getParsed() { return &parsedIR; } + + /// Return the definition for the given operation, or nullptr if the given + /// operation does not have a definition. + const OperationDefinition &getOpDef(Operation *op) const { + return *asmState.getOpDef(op); + } + +private: + // The context and state required to parse. + MLIRContext context; + llvm::SourceMgr sourceMgr; + DialectRegistry registry; + FallbackAsmResourceMap fallbackResourceMap; + + // Storage of textual parsing results. + AsmParserState asmState; + + // Parsed IR. + Block parsedIR; + + // The RewriteBuffer is doing most of the real work. + llvm::RewriteBuffer rewriteBuffer; + + // Start of the original input, used to compute offset. + const char *start; +}; + +std::unique_ptr RewritePad::init(StringRef inputFilename, + StringRef outputFilename) { + std::unique_ptr r = std::make_unique(); + + // Register all the dialects needed. + registerAllDialects(r->registry); + + // Set up the input file. + std::string errorMessage; + std::unique_ptr file = + openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return nullptr; + } + r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); + + // Set up the MLIR context and error handling. + r->context.appendDialectRegistry(r->registry); + + // Record the start of the buffer to compute offsets with. + unsigned curBuf = r->sourceMgr.getMainFileID(); + const llvm::MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf); + r->start = curMB->getBufferStart(); + r->rewriteBuffer.Initialize(curMB->getBuffer()); + + // Parse and populate the AsmParserState. + ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true, + &r->fallbackResourceMap); + // Always allow unregistered. + r->context.allowUnregisteredDialects(true); + if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig, + &r->asmState))) + return nullptr; + + return r; +} + +/// Return the source code associated with the operation name. +SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; } + +/// Return whether the operation was printed using generic syntax in original +/// buffer. +bool isGeneric(const OperationDefinition &op) { + return op.loc.Start.getPointer()[0] == '"'; +} + +inline int asMainReturnCode(LogicalResult r) { + return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE; +} + +/// Reriter function to invoke. +using RewriterFunction = std::function; + +/// Structure to group information about a rewriter (argument to invoke via +/// mlir-tblgen, description, and rewriter function). +class RewriterInfo { +public: + /// RewriterInfo constructor should not be invoked directly, instead use + /// RewriterRegistration or registerRewriter. + RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter) + : arg(arg), description(description), rewriter(std::move(rewriter)) {} + + /// Invokes the rewriter and returns whether the rewriter failed. + LogicalResult invoke(mlir::RewritePad &rewriteState, raw_ostream &os) const { + assert(rewriter && "Cannot call rewriter with null rewriter"); + return rewriter(rewriteState, os); + } + + /// Returns the command line option that may be passed to 'mlir-rewrite' to + /// invoke this rewriter. + StringRef getRewriterArgument() const { return arg; } + + /// Returns a description for the rewriter. + StringRef getRewriterDescription() const { return description; } + +private: + // The argument with which to invoke the rewriter via mlir-tblgen. + StringRef arg; + + // Description of the rewriter. + StringRef description; + + // Rewritererator function. + RewriterFunction rewriter; +}; + +static llvm::ManagedStatic> rewriterRegistry; + +/// Adds command line option for each registered rewriter. +struct RewriterNameParser : public llvm::cl::parser { + RewriterNameParser(llvm::cl::Option &opt); + + void printOptionInfo(const llvm::cl::Option &o, + size_t globalWidth) const override; +}; + +/// RewriterRegistration provides a global initializer that registers a rewriter +/// function. +struct RewriterRegistration { + RewriterRegistration(StringRef arg, StringRef description, + const RewriterFunction &function); +}; + +RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description, + const RewriterFunction &function) { + rewriterRegistry->emplace_back(arg, description, function); +} + +RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt) + : llvm::cl::parser(opt) { + for (const auto &kv : *rewriterRegistry) { + addLiteralOption(kv.getRewriterArgument(), &kv, + kv.getRewriterDescription()); + } +} + +void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o, + size_t globalWidth) const { + RewriterNameParser *tp = const_cast(this); + llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(), + [](const RewriterNameParser::OptionInfo *vT1, + const RewriterNameParser::OptionInfo *vT2) { + return vT1->Name.compare(vT2->Name); + }); + using llvm::cl::parser; + parser::printOptionInfo(o, globalWidth); +} + +} // namespace mlir + +// TODO: Make these injectable too in non-global way. +static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"}; +static llvm::cl::opt simpleRenameOpName{ + "simple-rename-op-name", llvm::cl::desc("Name of op to match on"), + llvm::cl::cat(clSimpleRenameCategory)}; +static llvm::cl::opt simpleRenameMatch{ + "simple-rename-match", llvm::cl::desc("Match string for rename"), + llvm::cl::cat(clSimpleRenameCategory)}; +static llvm::cl::opt simpleRenameReplace{ + "simple-rename-replace", llvm::cl::desc("Replace string for rename"), + llvm::cl::cat(clSimpleRenameCategory)}; + +// Rewriter that does simple renames. +LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) { + StringRef opName = simpleRenameOpName; + StringRef match = simpleRenameMatch; + StringRef replace = simpleRenameReplace; + llvm::Regex regex(match); + + rewriteState.getParsed()->walk([&](Operation *op) { + if (op->getName().getStringRef() != opName) + return; + + const OperationDefinition &opDef = rewriteState.getOpDef(op); + SMRange range = getOpRange(opDef); + // This is a little bit overkill for simple. + std::string str = regex.sub(replace, rewriteState.getSourceString(range)); + rewriteState.replaceRange(range, str); + }); + return success(); +} + +static mlir::RewriterRegistration rewriteSimpleRename("simple-rename", + "Perform a simple rename", + simpleRename); + +// Rewriter that insert range markers. +LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) { + for (auto it : rewriteState.getOpDefs()) { + auto [startOp, endOp] = getOpRange(it); + + rewriteState.insertText(startOp, "《"); + rewriteState.insertText(endOp, "》"); + + auto nameRange = getOpNameRange(it); + + if (isGeneric(it)) { + rewriteState.insertText(nameRange.Start, "〖"); + rewriteState.insertText(nameRange.End, "〗"); + } else { + rewriteState.insertText(nameRange.Start, "〔"); + rewriteState.insertText(nameRange.End, "〕"); + } + } + + // Highlight all comment lines. + // TODO: Could be replaced if this is kept in memory. + for (auto commentLine : rewriteState.getSingleLineComments()) { + rewriteState.insertText(commentLine.Start, "❰"); + rewriteState.insertText(commentLine.End, "❱"); + } + + return success(); +} + +static mlir::RewriterRegistration + rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges); + +int main(int argc, char **argv) { + static llvm::cl::opt inputFilename( + llvm::cl::Positional, llvm::cl::desc(""), + llvm::cl::init("-")); + + static llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + + llvm::cl::opt + rewriter("", llvm::cl::desc("Rewriter to run")); + + std::string helpHeader = "mlir-rewrite"; + + llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader); + + // If no rewriter has been selected, exit with error code. Could also just + // return but its unlikely this was intentionally being used as `cp`. + if (!rewriter) { + llvm::errs() << "No rewriter selected!\n"; + return mlir::asMainReturnCode(mlir::failure()); + } + + // Set up rewrite buffer. + auto rewriterOr = RewritePad::init(inputFilename, outputFilename); + if (!rewriterOr) + return mlir::asMainReturnCode(mlir::failure()); + + // Set up the output file. + std::string errorMessage; + auto output = openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + return mlir::asMainReturnCode(mlir::failure()); + } + + LogicalResult result = rewriter->invoke(*rewriterOr, output->os()); + if (succeeded(result)) { + rewriterOr->write(output->os()); + output->keep(); + } + return mlir::asMainReturnCode(result); +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index e52439f00879f..2e520fca978e2 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9740,6 +9740,28 @@ cc_binary( ], ) +cc_binary( + name = "mlir-rewrite", + srcs = ["tools/mlir-rewrite/mlir-rewrite.cpp"], + deps = [ + ":AllExtensions", + ":AllPassesAndDialects", + ":AffineAnalysis", + ":Analysis", + ":AsmParser", + ":CastInterfaces", + ":Dialect", + ":Parser", + ":ParseUtilities", + ":Pass", + ":Transforms", + ":TransformUtils", + ":Support", + ":IR", + "//llvm:Support", + ] +) + cc_library( name = "MlirJitRunner", srcs = ["lib/ExecutionEngine/JitRunner.cpp"], @@ -10603,6 +10625,7 @@ cc_library( ":AtomicInterfaces", ":AtomicInterfacesIncGen", ":ControlFlowInterfaces", + ":ConvertToLLVMInterface", ":FuncDialect", ":IR", ":LLVMDialect", @@ -10650,6 +10673,7 @@ cc_library( ":ArithToLLVM", ":ControlFlowToLLVM", ":ConversionPassIncGen", + ":ConvertToLLVMInterface", ":FuncToLLVM", ":LLVMCommonConversion", ":LLVMDialect",