From 835ad09a5fab61af1504779030763dd3f9a306da Mon Sep 17 00:00:00 2001 From: LiuYuanqiang Date: Sun, 10 Sep 2023 01:49:49 +0800 Subject: [PATCH] [mlir][pybind] export more options on enable_ir_printing() api --- mlir/include/mlir-c/IR.h | 9 ++++++++ mlir/include/mlir-c/Pass.h | 8 ++++++- mlir/include/mlir/CAPI/IR.h | 3 +++ mlir/include/mlir/Pass/PassManager.h | 11 +++++++++- mlir/lib/Bindings/Python/Pass.cpp | 33 ++++++++++++++++++++++++++-- mlir/lib/CAPI/IR/IR.cpp | 8 +++++++ mlir/lib/CAPI/IR/Pass.cpp | 18 +++++++++++++-- mlir/lib/Pass/IRPrinting.cpp | 5 ++++- 8 files changed, 88 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 82da511f807a3..422c3eb9cf58a 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -56,6 +56,7 @@ DEFINE_C_API_STRUCT(MlirDialectRegistry, void); DEFINE_C_API_STRUCT(MlirOperation, void); DEFINE_C_API_STRUCT(MlirOpOperand, void); DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); +DEFINE_C_API_STRUCT(MlirIRPrinterConfig, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); DEFINE_C_API_STRUCT(MlirSymbolTable, void); @@ -450,6 +451,14 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags); +//===----------------------------------------------------------------------===// +// IR Printing config API. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirIRPrinterConfig mlirIRPrinterConfigCreate(void); + +MLIR_CAPI_EXPORTED void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config); + //===----------------------------------------------------------------------===// // Bytecode printing flags API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 35db138305d1e..6c0fcb98409a5 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -75,8 +75,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); /// Enable mlir-print-ir-after-all. +// MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( +// MlirPassManager passManager, bool printBeforePass, bool printAfterPass, +// bool printModuleScope, bool printAfterOnlyOnChange, +// bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags); + MLIR_CAPI_EXPORTED void -mlirPassManagerEnableIRPrinting(MlirPassManager passManager); +mlirPassManagerEnableIRPrinting(MlirPassManager passManager, + MlirIRPrinterConfig config); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 1836cb0acb67e..488c6fb80836c 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/Pass/PassManager.h" DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState) DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig) @@ -30,6 +31,8 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) +DEFINE_C_API_PTR_METHODS(MlirIRPrinterConfig, + mlir::PassManager::IRPrinterConfig) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable) diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 1b2e6a3bc82bb..5e383bdf373dd 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -307,7 +307,8 @@ class PassManager : public OpPassManager { /// IR. explicit IRPrinterConfig( bool printModuleScope = false, bool printAfterOnlyOnChange = false, - bool printAfterOnlyOnFailure = false, + bool printAfterOnlyOnFailure = false, bool printBeforePass = false, + bool printAfterPass = false, OpPrintingFlags opPrintingFlags = OpPrintingFlags()); virtual ~IRPrinterConfig(); @@ -338,6 +339,10 @@ class PassManager : public OpPassManager { return printAfterOnlyOnFailure; } + bool shouldPrintBeforePass() const { return printBeforePass; } + + bool shouldPrintAfterPass() const { return printAfterPass; } + /// Returns the printing flags to be used to print the IR. OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; } @@ -353,6 +358,10 @@ class PassManager : public OpPassManager { /// the pass failed. bool printAfterOnlyOnFailure; + bool printBeforePass; + + bool printAfterPass; + /// Flags to control printing behavior. OpPrintingFlags opPrintingFlags; }; diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index a68421b61641f..943cc1cbc19e6 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -74,9 +74,38 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "Releases (leaks) the backing pass manager (testing)") .def( "enable_ir_printing", - [](PyPassManager &passManager) { - mlirPassManagerEnableIRPrinting(passManager.get()); + [](PyPassManager &passManager, bool printBeforePass, + bool printAfterPass, bool printModuleScope, + bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, + std::optional largeElementsLimit, bool enableDebugInfo, + bool printGenericOpForm) { + MlirIRPrinterConfig config = mlirIRPrinterConfigCreate(); + mlirPassManagerEnableIRPrinting(passManager.get(), config); + mlirIRPrinterConfigDestroy(config); + // MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + // if (largeElementsLimit) + // mlirOpPrintingFlagsElideLargeElementsAttrs(flags, + // *largeElementsLimit); + // if (enableDebugInfo) + // mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + // /*prettyForm=*/false); + // if (printGenericOpForm) + // mlirOpPrintingFlagsPrintGenericOpForm(flags); + // mlirPassManagerEnableIRPrinting(passManager.get(), + // printBeforePass, + // printAfterPass, printModuleScope, + // printAfterOnlyOnChange, + // printAfterOnlyOnFailure, flags); + // mlirOpPrintingFlagsDestroy(flags); }, + py::arg("print_before_pass") = true, + py::arg("print_after_pass") = true, + py::arg("print_module_scope") = true, + py::arg("print_after_only_on_change") = true, + py::arg("print_after_only_on_failure") = false, + py::arg("large_elements_limit") = py::none(), + py::arg("enable_debug_info") = false, + py::arg("print_generic_op_form") = false, "Enable mlir-print-ir-after-all.") .def( "enable_verifier", diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index cdb64f4ec4a40..74f5211f5e991 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -219,6 +219,14 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { unwrap(flags)->assumeVerified(); } +MlirIRPrinterConfig mlirIRPrinterConfigCreate() { + return wrap(new PassManager::IRPrinterConfig()); +} + +void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config) { + delete unwrap(config); +} + //===----------------------------------------------------------------------===// // Bytecode printing flags API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index d242baae99c08..0bd77036ccbf0 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/Pass/PassManager.h" +#include #include using namespace mlir; @@ -44,8 +45,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, return wrap(unwrap(passManager)->run(unwrap(op))); } -void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { - return unwrap(passManager)->enableIRPrinting(); +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, + MlirIRPrinterConfig config) { + std::function shouldPrintBeforePass = nullptr; + std::function shouldPrintAfterPass = nullptr; + if (unwrap(config)->shouldPrintBeforePass()) + shouldPrintBeforePass = [](Pass *, Operation *) { return true; }; + if (unwrap(config)->shouldPrintAfterPass()) + shouldPrintAfterPass = [](Pass *, Operation *) { return true; }; + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + unwrap(config)->shouldPrintAtModuleScope(), + unwrap(config)->shouldPrintAfterOnlyOnChange(), + unwrap(config)->shouldPrintAfterOnlyOnFailure(), + /*out=*/llvm::errs(), + unwrap(config)->getOpPrintingFlags()); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 72b94eeb0123f..a9d847c7354b7 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -133,10 +133,13 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, + bool printBeforePass, + bool printAfterPass, OpPrintingFlags opPrintingFlags) : printModuleScope(printModuleScope), printAfterOnlyOnChange(printAfterOnlyOnChange), printAfterOnlyOnFailure(printAfterOnlyOnFailure), + printBeforePass(printBeforePass), printAfterPass(printAfterPass), opPrintingFlags(opPrintingFlags) {} PassManager::IRPrinterConfig::~IRPrinterConfig() = default; @@ -172,7 +175,7 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, raw_ostream &out) : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure, opPrintingFlags), + printAfterOnlyOnFailure, false, false, opPrintingFlags), shouldPrintBeforePass(std::move(shouldPrintBeforePass)), shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&