Skip to content

Commit 45c6dd1

Browse files
committed
[mlir][pybind] export more options on enable_ir_printing() api
1 parent 37b0889 commit 45c6dd1

File tree

3 files changed

+49
-6
lines changed

3 files changed

+49
-6
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
7575
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
7676

7777
/// Enable mlir-print-ir-after-all.
78-
MLIR_CAPI_EXPORTED void
79-
mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
78+
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
79+
MlirPassManager passManager, bool printBeforePass, bool printAfterPass,
80+
bool printModuleScope, bool printAfterOnlyOnChange,
81+
bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags);
8082

8183
/// Enable / disable verify-each.
8284
MLIR_CAPI_EXPORTED void

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,34 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
7373
"Releases (leaks) the backing pass manager (testing)")
7474
.def(
7575
"enable_ir_printing",
76-
[](PyPassManager &passManager) {
77-
mlirPassManagerEnableIRPrinting(passManager.get());
76+
[](PyPassManager &passManager, bool printBeforePass,
77+
bool printAfterPass, bool printModuleScope,
78+
bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure,
79+
std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
80+
bool printGenericOpForm) {
81+
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
82+
if (largeElementsLimit)
83+
mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
84+
*largeElementsLimit);
85+
if (enableDebugInfo)
86+
mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
87+
/*prettyForm=*/false);
88+
if (printGenericOpForm)
89+
mlirOpPrintingFlagsPrintGenericOpForm(flags);
90+
mlirPassManagerEnableIRPrinting(passManager.get(), printBeforePass,
91+
printAfterPass, printModuleScope,
92+
printAfterOnlyOnChange,
93+
printAfterOnlyOnFailure, flags);
94+
mlirOpPrintingFlagsDestroy(flags);
7895
},
96+
py::arg("print_before_pass") = true,
97+
py::arg("print_after_pass") = true,
98+
py::arg("print_module_scope") = true,
99+
py::arg("print_after_only_on_change") = true,
100+
py::arg("print_after_only_on_failure") = false,
101+
py::arg("large_elements_limit") = py::none(),
102+
py::arg("enable_debug_info") = false,
103+
py::arg("print_generic_op_form") = false,
79104
"Enable mlir-print-ir-after-all.")
80105
.def(
81106
"enable_verifier",

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/CAPI/Support.h"
1414
#include "mlir/CAPI/Utils.h"
1515
#include "mlir/Pass/PassManager.h"
16+
#include <functional>
1617
#include <optional>
1718

1819
using namespace mlir;
@@ -44,8 +45,23 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
4445
return wrap(unwrap(passManager)->run(unwrap(op)));
4546
}
4647

47-
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
48-
return unwrap(passManager)->enableIRPrinting();
48+
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
49+
bool printBeforePass, bool printAfterPass,
50+
bool printModuleScope,
51+
bool printAfterOnlyOnChange,
52+
bool printAfterOnlyOnFailure,
53+
MlirOpPrintingFlags flags) {
54+
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass = nullptr;
55+
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass = nullptr;
56+
if (printBeforePass)
57+
shouldPrintBeforePass = [](Pass *, Operation *) { return true; };
58+
if (printAfterPass)
59+
shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
60+
return unwrap(passManager)
61+
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
62+
printModuleScope, printAfterOnlyOnChange,
63+
printAfterOnlyOnFailure, /*out=*/llvm::errs(),
64+
*unwrap(flags));
4965
}
5066

5167
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

0 commit comments

Comments
 (0)