Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir-c/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
bool printModuleScope, bool printAfterOnlyOnChange,
bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags,
MlirStringRef treePrintingPath);

/// Enable / disable verify-each.
MLIR_CAPI_EXPORTED void
Expand Down
17 changes: 15 additions & 2 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,33 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"enable_ir_printing",
[](PyPassManager &passManager, bool printBeforeAll,
bool printAfterAll, bool printModuleScope, bool printAfterChange,
bool printAfterFailure,
bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool printGenericOpForm,
std::optional<std::string> optionalTreePrintingPath) {
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (largeElementsLimit)
mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
*largeElementsLimit);
if (enableDebugInfo)
mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
/*prettyForm=*/false);
if (printGenericOpForm)
mlirOpPrintingFlagsPrintGenericOpForm(flags);
std::string treePrintingPath = "";
if (optionalTreePrintingPath.has_value())
treePrintingPath = optionalTreePrintingPath.value();
mlirPassManagerEnableIRPrinting(
passManager.get(), printBeforeAll, printAfterAll,
printModuleScope, printAfterChange, printAfterFailure,
printModuleScope, printAfterChange, printAfterFailure, flags,
mlirStringRefCreate(treePrintingPath.data(),
treePrintingPath.size()));
mlirOpPrintingFlagsDestroy(flags);
},
"print_before_all"_a = false, "print_after_all"_a = true,
"print_module_scope"_a = false, "print_after_change"_a = false,
"print_after_failure"_a = false,
"large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
"print_generic_op_form"_a = false,
"tree_printing_dir_path"_a = py::none(),
"Enable IR printing, default as mlir-print-ir-after-all.")
.def(
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/CAPI/IR/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
bool printModuleScope,
bool printAfterOnlyOnChange,
bool printAfterOnlyOnFailure,
MlirOpPrintingFlags flags,
MlirStringRef treePrintingPath) {
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
return printBeforeAll;
Expand All @@ -60,13 +61,14 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
return unwrap(passManager)
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, printAfterOnlyOnChange,
printAfterOnlyOnFailure);
printAfterOnlyOnFailure, /*out=*/llvm::errs(),
*unwrap(flags));

unwrap(passManager)
->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, printAfterOnlyOnChange,
printAfterOnlyOnFailure,
unwrap(treePrintingPath));
unwrap(treePrintingPath), *unwrap(flags));
}

void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
Expand Down
4 changes: 4 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class PassManager:
print_module_scope: bool = False,
print_after_change: bool = False,
print_after_failure: bool = False,
large_elements_limit: int | None = None,
enable_debug_info: bool = False,
print_generic_op_form: bool = False,
tree_printing_dir_path: str | None = None,
) -> None: ...
def enable_verifier(self, enable: bool) -> None: ...
@staticmethod
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,33 @@ def testPrintIrBeforeAndAfterAll():
pm.run(module)


# CHECK-LABEL: TEST: testPrintIrLargeLimitElements
@run
def testPrintIrLargeLimitElements():
with Context() as ctx:
module = ModuleOp.parse(
"""
module {
func.func @main() -> tensor<3xi64> {
%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64>
return %0 : tensor<3xi64>
}
}
"""
)
pm = PassManager.parse("builtin.module(canonicalize)")
ctx.enable_multithreading(False)
pm.enable_ir_printing(large_elements_limit=2)
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
# CHECK: module {
# CHECK: func.func @main() -> tensor<3xi64> {
# CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64>
# CHECK: return %[[CST]] : tensor<3xi64>
# CHECK: }
# CHECK: }
pm.run(module)


# CHECK-LABEL: TEST: testPrintIrTree
@run
def testPrintIrTree():
Expand Down
Loading