Skip to content

Commit 835ad09

Browse files
committed
[mlir][pybind] export more options on enable_ir_printing() api
1 parent e8e6795 commit 835ad09

File tree

8 files changed

+88
-7
lines changed

8 files changed

+88
-7
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
5656
DEFINE_C_API_STRUCT(MlirOperation, void);
5757
DEFINE_C_API_STRUCT(MlirOpOperand, void);
5858
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
59+
DEFINE_C_API_STRUCT(MlirIRPrinterConfig, void);
5960
DEFINE_C_API_STRUCT(MlirBlock, void);
6061
DEFINE_C_API_STRUCT(MlirRegion, void);
6162
DEFINE_C_API_STRUCT(MlirSymbolTable, void);
@@ -450,6 +451,14 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
450451
MLIR_CAPI_EXPORTED void
451452
mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
452453

454+
//===----------------------------------------------------------------------===//
455+
// IR Printing config API.
456+
//===----------------------------------------------------------------------===//
457+
458+
MLIR_CAPI_EXPORTED MlirIRPrinterConfig mlirIRPrinterConfigCreate(void);
459+
460+
MLIR_CAPI_EXPORTED void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config);
461+
453462
//===----------------------------------------------------------------------===//
454463
// Bytecode printing flags API.
455464
//===----------------------------------------------------------------------===//

mlir/include/mlir-c/Pass.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
7575
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
7676

7777
/// Enable mlir-print-ir-after-all.
78+
// MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
79+
// MlirPassManager passManager, bool printBeforePass, bool printAfterPass,
80+
// bool printModuleScope, bool printAfterOnlyOnChange,
81+
// bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags);
82+
7883
MLIR_CAPI_EXPORTED void
79-
mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
84+
mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
85+
MlirIRPrinterConfig config);
8086

8187
/// Enable / disable verify-each.
8288
MLIR_CAPI_EXPORTED void

mlir/include/mlir/CAPI/IR.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/BuiltinOps.h"
2121
#include "mlir/IR/MLIRContext.h"
2222
#include "mlir/IR/Operation.h"
23+
#include "mlir/Pass/PassManager.h"
2324

2425
DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
2526
DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
@@ -30,6 +31,8 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
3031
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
3132
DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
3233
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
34+
DEFINE_C_API_PTR_METHODS(MlirIRPrinterConfig,
35+
mlir::PassManager::IRPrinterConfig)
3336
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
3437
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
3538

mlir/include/mlir/Pass/PassManager.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ class PassManager : public OpPassManager {
307307
/// IR.
308308
explicit IRPrinterConfig(
309309
bool printModuleScope = false, bool printAfterOnlyOnChange = false,
310-
bool printAfterOnlyOnFailure = false,
310+
bool printAfterOnlyOnFailure = false, bool printBeforePass = false,
311+
bool printAfterPass = false,
311312
OpPrintingFlags opPrintingFlags = OpPrintingFlags());
312313
virtual ~IRPrinterConfig();
313314

@@ -338,6 +339,10 @@ class PassManager : public OpPassManager {
338339
return printAfterOnlyOnFailure;
339340
}
340341

342+
bool shouldPrintBeforePass() const { return printBeforePass; }
343+
344+
bool shouldPrintAfterPass() const { return printAfterPass; }
345+
341346
/// Returns the printing flags to be used to print the IR.
342347
OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; }
343348

@@ -353,6 +358,10 @@ class PassManager : public OpPassManager {
353358
/// the pass failed.
354359
bool printAfterOnlyOnFailure;
355360

361+
bool printBeforePass;
362+
363+
bool printAfterPass;
364+
356365
/// Flags to control printing behavior.
357366
OpPrintingFlags opPrintingFlags;
358367
};

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,38 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
7474
"Releases (leaks) the backing pass manager (testing)")
7575
.def(
7676
"enable_ir_printing",
77-
[](PyPassManager &passManager) {
78-
mlirPassManagerEnableIRPrinting(passManager.get());
77+
[](PyPassManager &passManager, bool printBeforePass,
78+
bool printAfterPass, bool printModuleScope,
79+
bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure,
80+
std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
81+
bool printGenericOpForm) {
82+
MlirIRPrinterConfig config = mlirIRPrinterConfigCreate();
83+
mlirPassManagerEnableIRPrinting(passManager.get(), config);
84+
mlirIRPrinterConfigDestroy(config);
85+
// MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
86+
// if (largeElementsLimit)
87+
// mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
88+
// *largeElementsLimit);
89+
// if (enableDebugInfo)
90+
// mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
91+
// /*prettyForm=*/false);
92+
// if (printGenericOpForm)
93+
// mlirOpPrintingFlagsPrintGenericOpForm(flags);
94+
// mlirPassManagerEnableIRPrinting(passManager.get(),
95+
// printBeforePass,
96+
// printAfterPass, printModuleScope,
97+
// printAfterOnlyOnChange,
98+
// printAfterOnlyOnFailure, flags);
99+
// mlirOpPrintingFlagsDestroy(flags);
79100
},
101+
py::arg("print_before_pass") = true,
102+
py::arg("print_after_pass") = true,
103+
py::arg("print_module_scope") = true,
104+
py::arg("print_after_only_on_change") = true,
105+
py::arg("print_after_only_on_failure") = false,
106+
py::arg("large_elements_limit") = py::none(),
107+
py::arg("enable_debug_info") = false,
108+
py::arg("print_generic_op_form") = false,
80109
"Enable mlir-print-ir-after-all.")
81110
.def(
82111
"enable_verifier",

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,14 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
219219
unwrap(flags)->assumeVerified();
220220
}
221221

222+
MlirIRPrinterConfig mlirIRPrinterConfigCreate() {
223+
return wrap(new PassManager::IRPrinterConfig());
224+
}
225+
226+
void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config) {
227+
delete unwrap(config);
228+
}
229+
222230
//===----------------------------------------------------------------------===//
223231
// Bytecode printing flags API.
224232
//===----------------------------------------------------------------------===//

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/CAPI/Support.h"
1414
#include "mlir/CAPI/Utils.h"
1515
#include "mlir/Pass/PassManager.h"
16+
#include <functional>
1617
#include <optional>
1718

1819
using namespace mlir;
@@ -44,8 +45,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
4445
return wrap(unwrap(passManager)->run(unwrap(op)));
4546
}
4647

47-
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
48-
return unwrap(passManager)->enableIRPrinting();
48+
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
49+
MlirIRPrinterConfig config) {
50+
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass = nullptr;
51+
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass = nullptr;
52+
if (unwrap(config)->shouldPrintBeforePass())
53+
shouldPrintBeforePass = [](Pass *, Operation *) { return true; };
54+
if (unwrap(config)->shouldPrintAfterPass())
55+
shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
56+
return unwrap(passManager)
57+
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
58+
unwrap(config)->shouldPrintAtModuleScope(),
59+
unwrap(config)->shouldPrintAfterOnlyOnChange(),
60+
unwrap(config)->shouldPrintAfterOnlyOnFailure(),
61+
/*out=*/llvm::errs(),
62+
unwrap(config)->getOpPrintingFlags());
4963
}
5064

5165
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

mlir/lib/Pass/IRPrinting.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
133133
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
134134
bool printAfterOnlyOnChange,
135135
bool printAfterOnlyOnFailure,
136+
bool printBeforePass,
137+
bool printAfterPass,
136138
OpPrintingFlags opPrintingFlags)
137139
: printModuleScope(printModuleScope),
138140
printAfterOnlyOnChange(printAfterOnlyOnChange),
139141
printAfterOnlyOnFailure(printAfterOnlyOnFailure),
142+
printBeforePass(printBeforePass), printAfterPass(printAfterPass),
140143
opPrintingFlags(opPrintingFlags) {}
141144
PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
142145

@@ -172,7 +175,7 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
172175
bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
173176
raw_ostream &out)
174177
: IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
175-
printAfterOnlyOnFailure, opPrintingFlags),
178+
printAfterOnlyOnFailure, false, false, opPrintingFlags),
176179
shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
177180
shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
178181
assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&

0 commit comments

Comments
 (0)