Skip to content

Commit b78e24c

Browse files
PragmaTwicesvkeerthy
authored andcommitted
[MLIR][Python] Expose PassManager::enableStatistics to CAPI and Python (#162591)
`PassManager::enableStatistics` seems currently missing in both C API and Python bindings. So here we added them in this PR, which includes the `PassDisplayMode` enum type and the `EnableStatistics` method.
1 parent 44bf5d9 commit b78e24c

File tree

4 files changed

+62
-0
lines changed

4 files changed

+62
-0
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable);
9292
MLIR_CAPI_EXPORTED void
9393
mlirPassManagerEnableTiming(MlirPassManager passManager);
9494

95+
/// Enumerated type of pass display modes.
96+
/// Mainly used in mlirPassManagerEnableStatistics.
97+
typedef enum {
98+
MLIR_PASS_DISPLAY_MODE_LIST,
99+
MLIR_PASS_DISPLAY_MODE_PIPELINE,
100+
} MlirPassDisplayMode;
101+
102+
/// Enable pass statistics.
103+
MLIR_CAPI_EXPORTED void
104+
mlirPassManagerEnableStatistics(MlirPassManager passManager,
105+
MlirPassDisplayMode displayMode);
106+
95107
/// Nest an OpPassManager under the top-level PassManager, the nested
96108
/// passmanager will only run on operations matching the provided name.
97109
/// The returned OpPassManager will be destroyed when the parent is destroyed.

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ class PyPassManager {
5757

5858
/// Create the `mlir.passmanager` here.
5959
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
60+
//----------------------------------------------------------------------------
61+
// Mapping of enumerated types
62+
//----------------------------------------------------------------------------
63+
nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
64+
.value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
65+
.value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
66+
6067
//----------------------------------------------------------------------------
6168
// Mapping of MlirExternalPass
6269
//----------------------------------------------------------------------------
@@ -139,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
139146
mlirPassManagerEnableTiming(passManager.get());
140147
},
141148
"Enable pass timing.")
149+
.def(
150+
"enable_statistics",
151+
[](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
152+
mlirPassManagerEnableStatistics(passManager.get(), displayMode);
153+
},
154+
"displayMode"_a =
155+
MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
156+
"Enable pass statistics.")
142157
.def_static(
143158
"parse",
144159
[](const std::string &pipeline, DefaultingPyMlirContext context) {

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/CAPI/Support.h"
1414
#include "mlir/CAPI/Utils.h"
1515
#include "mlir/Pass/PassManager.h"
16+
#include "llvm/Support/ErrorHandling.h"
1617
#include <optional>
1718

1819
using namespace mlir;
@@ -79,6 +80,20 @@ void mlirPassManagerEnableTiming(MlirPassManager passManager) {
7980
unwrap(passManager)->enableTiming();
8081
}
8182

83+
void mlirPassManagerEnableStatistics(MlirPassManager passManager,
84+
MlirPassDisplayMode displayMode) {
85+
PassDisplayMode mode;
86+
switch (displayMode) {
87+
case MLIR_PASS_DISPLAY_MODE_LIST:
88+
mode = PassDisplayMode::List;
89+
break;
90+
case MLIR_PASS_DISPLAY_MODE_PIPELINE:
91+
mode = PassDisplayMode::Pipeline;
92+
break;
93+
}
94+
unwrap(passManager)->enableStatistics(mode);
95+
}
96+
8297
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
8398
MlirStringRef operationName) {
8499
return wrap(&unwrap(passManager)->nest(unwrap(operationName)));

mlir/test/python/pass_manager.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,23 @@ def print_file_tree(directory, prefix=""):
435435

436436
print_file_tree(temp_dir)
437437
log("// Tree printing end")
438+
439+
440+
# CHECK-LABEL: TEST: testEnableStatistics
441+
@run
442+
def testEnableStatistics():
443+
with Context() as ctx:
444+
module = ModuleOp.parse(
445+
"""
446+
module {
447+
func.func @main() {
448+
%0 = arith.constant 10
449+
return
450+
}
451+
}
452+
"""
453+
)
454+
pm = PassManager.parse("builtin.module(canonicalize)")
455+
pm.enable_statistics()
456+
# CHECK: Pass statistics report
457+
pm.run(module)

0 commit comments

Comments
 (0)