diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 0d2e19ee7fb0a..1f63c6d0dcab8 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -92,6 +92,18 @@ mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); MLIR_CAPI_EXPORTED void mlirPassManagerEnableTiming(MlirPassManager passManager); +/// Enumerated type of pass display modes. +/// Mainly used in mlirPassManagerEnableStatistics. +typedef enum { + MLIR_PASS_DISPLAY_MODE_LIST, + MLIR_PASS_DISPLAY_MODE_PIPELINE, +} MlirPassDisplayMode; + +/// Enable pass statistics. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableStatistics(MlirPassManager passManager, + MlirPassDisplayMode displayMode); + /// Nest an OpPassManager under the top-level PassManager, the nested /// passmanager will only run on operations matching the provided name. /// The returned OpPassManager will be destroyed when the parent is destroyed. diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e489585fd5f50..f55f827f48c09 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -56,6 +56,13 @@ class PyPassManager { /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of enumerated types + //---------------------------------------------------------------------------- + nb::enum_(m, "PassDisplayMode") + .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) + .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); + //---------------------------------------------------------------------------- // Mapping of MlirExternalPass //---------------------------------------------------------------------------- @@ -138,6 +145,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableTiming(passManager.get()); }, "Enable pass timing.") + .def( + "enable_statistics", + [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics(passManager.get(), displayMode); + }, + "displayMode"_a = + MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "Enable pass statistics.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index b0a6ec1ace3cc..72bec11f7c7de 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/Pass/PassManager.h" +#include "llvm/Support/ErrorHandling.h" #include using namespace mlir; @@ -79,6 +80,20 @@ void mlirPassManagerEnableTiming(MlirPassManager passManager) { unwrap(passManager)->enableTiming(); } +void mlirPassManagerEnableStatistics(MlirPassManager passManager, + MlirPassDisplayMode displayMode) { + PassDisplayMode mode; + switch (displayMode) { + case MLIR_PASS_DISPLAY_MODE_LIST: + mode = PassDisplayMode::List; + break; + case MLIR_PASS_DISPLAY_MODE_PIPELINE: + mode = PassDisplayMode::Pipeline; + break; + } + unwrap(passManager)->enableStatistics(mode); +} + MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName) { return wrap(&unwrap(passManager)->nest(unwrap(operationName))); diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 5f92f5b52a09a..8e6208e142b13 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -435,3 +435,23 @@ def print_file_tree(directory, prefix=""): print_file_tree(temp_dir) log("// Tree printing end") + + +# CHECK-LABEL: TEST: testEnableStatistics +@run +def testEnableStatistics(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + func.func @main() { + %0 = arith.constant 10 + return + } + } + """ + ) + pm = PassManager.parse("builtin.module(canonicalize)") + pm.enable_statistics() + # CHECK: Pass statistics report + pm.run(module)