diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 2218ec0f47d19..6019071cfdaa2 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); /// Enable IR printing. +/// The treePrintingPath argument is an optional path to a directory +/// where the dumps will be produced. If it isn't provided then dumps +/// are produced to stderr. MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure); + bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 1d0e5ce2115a0..e8d28abe6d583 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -76,14 +76,21 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "enable_ir_printing", [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, - bool printAfterFailure) { + bool printAfterFailure, + std::optional optionalTreePrintingPath) { + std::string treePrintingPath = ""; + if (optionalTreePrintingPath.has_value()) + treePrintingPath = optionalTreePrintingPath.value(); mlirPassManagerEnableIRPrinting( passManager.get(), printBeforeAll, printAfterAll, - printModuleScope, printAfterChange, printAfterFailure); + printModuleScope, printAfterChange, printAfterFailure, + mlirStringRefCreate(treePrintingPath.data(), + treePrintingPath.size())); }, "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, + "tree_printing_dir_path"_a = py::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index a6c9fbd08d45a..01151eafeb526 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure) { + bool printAfterOnlyOnFailure, + MlirStringRef treePrintingPath) { auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { return printBeforeAll; }; auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { return printAfterAll; }; - return unwrap(passManager) - ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, - printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure); + if (unwrap(treePrintingPath).empty()) + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure); + + unwrap(passManager) + ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure, + unwrap(treePrintingPath)); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 7496703256235..a794a3fc6fa00 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -1,6 +1,6 @@ # RUN: %PYTHON %s 2>&1 | FileCheck %s -import gc, sys +import gc, os, sys, tempfile from mlir.ir import * from mlir.passmanager import * from mlir.dialects.func import FuncOp @@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll(): # CHECK: } # CHECK: } pm.run(module) + + +# CHECK-LABEL: TEST: testPrintIrTree +@run +def testPrintIrTree(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() { + %0 = arith.constant 10 + return + } + } + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing() + # CHECK-LABEL: // Tree printing begin + # CHECK: \-- builtin_module_no-symbol-name + # CHECK: \-- 0_canonicalize.mlir + # CHECK-LABEL: // Tree printing end + pm.run(module) + log("// Tree printing begin") + with tempfile.TemporaryDirectory() as temp_dir: + pm.enable_ir_printing(tree_printing_dir_path=temp_dir) + pm.run(module) + + def print_file_tree(directory, prefix=""): + entries = sorted(os.listdir(directory)) + for i, entry in enumerate(entries): + path = os.path.join(directory, entry) + connector = "\-- " if i == len(entries) - 1 else "|-- " + log(f"{prefix}{connector}{entry}") + if os.path.isdir(path): + print_file_tree( + path, prefix + (" " if i == len(entries) - 1 else "│ ") + ) + + print_file_tree(temp_dir) + log("// Tree printing end")