From a7d8bd9a6dd19da559d245044ee3b62b7e452516 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 27 Nov 2024 11:28:58 +0800 Subject: [PATCH 1/3] [MLIR][Python] enhance python ir printing with pringing flags --- mlir/include/mlir-c/Pass.h | 3 ++- mlir/lib/Bindings/Python/Pass.cpp | 17 +++++++++++++++-- mlir/lib/CAPI/IR/Pass.cpp | 6 ++++-- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 4 ++++ 4 files changed, 25 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 6019071cfdaa2..8fd8e9956a65a 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -81,7 +81,8 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath); + bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, + MlirStringRef treePrintingPath); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e8d28abe6d583..e991deaae2daa 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -76,20 +76,33 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "enable_ir_printing", [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, - bool printAfterFailure, + bool printAfterFailure, std::optional largeElementsLimit, + bool enableDebugInfo, bool printGenericOpForm, std::optional optionalTreePrintingPath) { + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (largeElementsLimit) + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, + *largeElementsLimit); + if (enableDebugInfo) + mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + /*prettyForm=*/false); + if (printGenericOpForm) + mlirOpPrintingFlagsPrintGenericOpForm(flags); std::string treePrintingPath = ""; if (optionalTreePrintingPath.has_value()) treePrintingPath = optionalTreePrintingPath.value(); mlirPassManagerEnableIRPrinting( passManager.get(), printBeforeAll, printAfterAll, - printModuleScope, printAfterChange, printAfterFailure, + printModuleScope, printAfterChange, printAfterFailure, flags, mlirStringRefCreate(treePrintingPath.data(), treePrintingPath.size())); + mlirOpPrintingFlagsDestroy(flags); }, "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, + "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, + "print_generic_op_form"_a = false, "tree_printing_dir_path"_a = py::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 01151eafeb526..883b7e8bb832d 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -49,6 +49,7 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, + MlirOpPrintingFlags flags, MlirStringRef treePrintingPath) { auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { return printBeforeAll; @@ -60,13 +61,14 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, return unwrap(passManager) ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure); + printAfterOnlyOnFailure, /*out=*/llvm::errs(), + *unwrap(flags)); unwrap(passManager) ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, - unwrap(treePrintingPath)); + unwrap(treePrintingPath), *unwrap(flags)); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 229979ae33608..0d2eaffe16d3e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -22,6 +22,10 @@ class PassManager: print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, + large_elements_limit: int | None = None, + enable_debug_info: bool = False, + print_generic_op_form: bool = False, + tree_printing_dir_path: str | None = None, ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod From d452e81131a6db19af92b2f9c86ec07646473211 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 4 Dec 2024 15:16:38 +0800 Subject: [PATCH 2/3] add test --- mlir/test/python/pass_manager.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index a794a3fc6fa00..0e555d0dc4858 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -342,6 +342,33 @@ def testPrintIrBeforeAndAfterAll(): pm.run(module) +# CHECK-LABEL: TEST: testPrintIrLargeLimitElements +@run +def testPrintIrLargeLimitElements(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() -> tensor<3xi64> { + %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64> + return %0 : tensor<3xi64> + } + } + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing(large_elements_limit=2) + # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- // + # CHECK: module { + # CHECK: func.func @main() -> tensor<3xi64> { + # CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64> + # CHECK: return %[[CST]] : tensor<3xi64> + # CHECK: } + # CHECK: } + pm.run(module) + + # CHECK-LABEL: TEST: testPrintIrTree @run def testPrintIrTree(): From ea27fb4481be5a42fa4227359ff0682e4d814fae Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 5 Dec 2024 16:39:14 +0800 Subject: [PATCH 3/3] update test --- mlir/test/python/pass_manager.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 0e555d0dc4858..ecac57e3302f0 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -359,13 +359,7 @@ def testPrintIrLargeLimitElements(): pm = PassManager.parse("builtin.module(canonicalize)") ctx.enable_multithreading(False) pm.enable_ir_printing(large_elements_limit=2) - # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- // - # CHECK: module { - # CHECK: func.func @main() -> tensor<3xi64> { # CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64> - # CHECK: return %[[CST]] : tensor<3xi64> - # CHECK: } - # CHECK: } pm.run(module)