diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 6019071cfdaa2..8fd8e9956a65a 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -81,7 +81,8 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath); + bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, + MlirStringRef treePrintingPath); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e8d28abe6d583..e991deaae2daa 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -76,20 +76,33 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "enable_ir_printing", [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, - bool printAfterFailure, + bool printAfterFailure, std::optional largeElementsLimit, + bool enableDebugInfo, bool printGenericOpForm, std::optional optionalTreePrintingPath) { + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (largeElementsLimit) + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, + *largeElementsLimit); + if (enableDebugInfo) + mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + /*prettyForm=*/false); + if (printGenericOpForm) + mlirOpPrintingFlagsPrintGenericOpForm(flags); std::string treePrintingPath = ""; if (optionalTreePrintingPath.has_value()) treePrintingPath = optionalTreePrintingPath.value(); mlirPassManagerEnableIRPrinting( passManager.get(), printBeforeAll, printAfterAll, - printModuleScope, printAfterChange, printAfterFailure, + printModuleScope, printAfterChange, printAfterFailure, flags, mlirStringRefCreate(treePrintingPath.data(), treePrintingPath.size())); + mlirOpPrintingFlagsDestroy(flags); }, "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, + "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, + "print_generic_op_form"_a = false, "tree_printing_dir_path"_a = py::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 01151eafeb526..883b7e8bb832d 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -49,6 +49,7 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, + MlirOpPrintingFlags flags, MlirStringRef treePrintingPath) { auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { return printBeforeAll; @@ -60,13 +61,14 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, return unwrap(passManager) ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure); + printAfterOnlyOnFailure, /*out=*/llvm::errs(), + *unwrap(flags)); unwrap(passManager) ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, - unwrap(treePrintingPath)); + unwrap(treePrintingPath), *unwrap(flags)); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 229979ae33608..0d2eaffe16d3e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -22,6 +22,10 @@ class PassManager: print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, + large_elements_limit: int | None = None, + enable_debug_info: bool = False, + print_generic_op_form: bool = False, + tree_printing_dir_path: str | None = None, ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index a794a3fc6fa00..ecac57e3302f0 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -342,6 +342,27 @@ def testPrintIrBeforeAndAfterAll(): pm.run(module) +# CHECK-LABEL: TEST: testPrintIrLargeLimitElements +@run +def testPrintIrLargeLimitElements(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() -> tensor<3xi64> { + %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64> + return %0 : tensor<3xi64> + } + } + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + ctx.enable_multithreading(False) + pm.enable_ir_printing(large_elements_limit=2) + # CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64> + pm.run(module) + + # CHECK-LABEL: TEST: testPrintIrTree @run def testPrintIrTree():