Skip to content

Commit 2e51e15

Browse files
authored
[MLIR][Python] enhance python ir printing with pringing flags (#117836)
Close #65854
1 parent a2acb2f commit 2e51e15

File tree

5 files changed

+46
-5
lines changed

5 files changed

+46
-5
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
8181
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
8282
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
8383
bool printModuleScope, bool printAfterOnlyOnChange,
84-
bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
84+
bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags,
85+
MlirStringRef treePrintingPath);
8586

8687
/// Enable / disable verify-each.
8788
MLIR_CAPI_EXPORTED void

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,33 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
7676
"enable_ir_printing",
7777
[](PyPassManager &passManager, bool printBeforeAll,
7878
bool printAfterAll, bool printModuleScope, bool printAfterChange,
79-
bool printAfterFailure,
79+
bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
80+
bool enableDebugInfo, bool printGenericOpForm,
8081
std::optional<std::string> optionalTreePrintingPath) {
82+
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
83+
if (largeElementsLimit)
84+
mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
85+
*largeElementsLimit);
86+
if (enableDebugInfo)
87+
mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
88+
/*prettyForm=*/false);
89+
if (printGenericOpForm)
90+
mlirOpPrintingFlagsPrintGenericOpForm(flags);
8191
std::string treePrintingPath = "";
8292
if (optionalTreePrintingPath.has_value())
8393
treePrintingPath = optionalTreePrintingPath.value();
8494
mlirPassManagerEnableIRPrinting(
8595
passManager.get(), printBeforeAll, printAfterAll,
86-
printModuleScope, printAfterChange, printAfterFailure,
96+
printModuleScope, printAfterChange, printAfterFailure, flags,
8797
mlirStringRefCreate(treePrintingPath.data(),
8898
treePrintingPath.size()));
99+
mlirOpPrintingFlagsDestroy(flags);
89100
},
90101
"print_before_all"_a = false, "print_after_all"_a = true,
91102
"print_module_scope"_a = false, "print_after_change"_a = false,
92103
"print_after_failure"_a = false,
104+
"large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
105+
"print_generic_op_form"_a = false,
93106
"tree_printing_dir_path"_a = py::none(),
94107
"Enable IR printing, default as mlir-print-ir-after-all.")
95108
.def(

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
4949
bool printModuleScope,
5050
bool printAfterOnlyOnChange,
5151
bool printAfterOnlyOnFailure,
52+
MlirOpPrintingFlags flags,
5253
MlirStringRef treePrintingPath) {
5354
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
5455
return printBeforeAll;
@@ -60,13 +61,14 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
6061
return unwrap(passManager)
6162
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
6263
printModuleScope, printAfterOnlyOnChange,
63-
printAfterOnlyOnFailure);
64+
printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65+
*unwrap(flags));
6466

6567
unwrap(passManager)
6668
->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
6769
printModuleScope, printAfterOnlyOnChange,
6870
printAfterOnlyOnFailure,
69-
unwrap(treePrintingPath));
71+
unwrap(treePrintingPath), *unwrap(flags));
7072
}
7173

7274
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ class PassManager:
2222
print_module_scope: bool = False,
2323
print_after_change: bool = False,
2424
print_after_failure: bool = False,
25+
large_elements_limit: int | None = None,
26+
enable_debug_info: bool = False,
27+
print_generic_op_form: bool = False,
28+
tree_printing_dir_path: str | None = None,
2529
) -> None: ...
2630
def enable_verifier(self, enable: bool) -> None: ...
2731
@staticmethod

mlir/test/python/pass_manager.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,27 @@ def testPrintIrBeforeAndAfterAll():
342342
pm.run(module)
343343

344344

345+
# CHECK-LABEL: TEST: testPrintIrLargeLimitElements
346+
@run
347+
def testPrintIrLargeLimitElements():
348+
with Context() as ctx:
349+
module = ModuleOp.parse(
350+
"""
351+
module {
352+
func.func @main() -> tensor<3xi64> {
353+
%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64>
354+
return %0 : tensor<3xi64>
355+
}
356+
}
357+
"""
358+
)
359+
pm = PassManager.parse("builtin.module(canonicalize)")
360+
ctx.enable_multithreading(False)
361+
pm.enable_ir_printing(large_elements_limit=2)
362+
# CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64>
363+
pm.run(module)
364+
365+
345366
# CHECK-LABEL: TEST: testPrintIrTree
346367
@run
347368
def testPrintIrTree():

0 commit comments

Comments
 (0)