Skip to content

Commit f26db3f

Browse files
authored
[MLIR] Add a OpWithFlags class that acts as a "stream modifier" to customize Operation streaming (#150636)
1 parent fcbcfe4 commit f26db3f

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

mlir/include/mlir/IR/Operation.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,29 @@ inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {
11021102
return os;
11031103
}
11041104

1105+
/// A wrapper class that allows for printing an operation with a set of flags,
1106+
/// useful to act as a "stream modifier" to customize printing an operation
1107+
/// with a stream using the operator<< overload, e.g.:
1108+
/// llvm::dbgs() << OpWithFlags(op, OpPrintingFlags().skipRegions());
1109+
class OpWithFlags {
1110+
public:
1111+
OpWithFlags(Operation *op, OpPrintingFlags flags = {})
1112+
: op(op), theFlags(flags) {}
1113+
OpPrintingFlags &flags() { return theFlags; }
1114+
const OpPrintingFlags &flags() const { return theFlags; }
1115+
1116+
private:
1117+
Operation *op;
1118+
OpPrintingFlags theFlags;
1119+
friend raw_ostream &operator<<(raw_ostream &os, const OpWithFlags &op);
1120+
};
1121+
1122+
inline raw_ostream &operator<<(raw_ostream &os,
1123+
const OpWithFlags &opWithFlags) {
1124+
opWithFlags.op->print(os, opWithFlags.flags());
1125+
return os;
1126+
}
1127+
11051128
} // namespace mlir
11061129

11071130
namespace llvm {

mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,8 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
8080
LogicalResult
8181
LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
8282
ArrayRef<const Liveness *> results) {
83-
LLVM_DEBUG(DBGS() << "[visitOperation] Enter: ";
84-
op->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
85-
llvm::dbgs() << "\n");
83+
LDBG() << "[visitOperation] Enter: "
84+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
8685
// This marks values of type (1.a) and (4) liveness as "live".
8786
if (!isMemoryEffectFree(op) || op->hasTrait<OpTrait::ReturnLike>()) {
8887
LDBG() << "[visitOperation] Operation has memory effects or is "

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "mlir/IR/Builders.h"
3737
#include "mlir/IR/BuiltinAttributes.h"
3838
#include "mlir/IR/Dialect.h"
39+
#include "mlir/IR/Operation.h"
3940
#include "mlir/IR/OperationSupport.h"
4041
#include "mlir/IR/SymbolTable.h"
4142
#include "mlir/IR/Value.h"
@@ -411,9 +412,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
411412
RunLivenessAnalysis &la,
412413
DenseSet<Value> &nonLiveSet,
413414
RDVFinalCleanupList &cl) {
414-
LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print(
415-
llvm::dbgs(), OpPrintingFlags().skipRegions());
416-
llvm::dbgs() << "\n");
415+
LDBG() << "Processing region branch op: "
416+
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
417417
// Mark live results of `regionBranchOp` in `liveResults`.
418418
auto markLiveResults = [&](BitVector &liveResults) {
419419
liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/Dominance.h"
1515
#include "mlir/IR/IRMapping.h"
1616
#include "mlir/IR/Iterators.h"
17+
#include "mlir/IR/Operation.h"
1718
#include "mlir/Interfaces/FunctionInterfaces.h"
1819
#include "mlir/Rewrite/PatternApplicator.h"
1920
#include "llvm/ADT/SmallPtrSet.h"
@@ -2092,8 +2093,9 @@ OperationLegalizer::legalize(Operation *op,
20922093

20932094
// If the operation has no regions, just print it here.
20942095
if (!isIgnored && op->getNumRegions() == 0) {
2095-
op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2096-
logger.getOStream() << "\n\n";
2096+
logger.startLine() << OpWithFlags(op,
2097+
OpPrintingFlags().printGenericOpForm())
2098+
<< "\n";
20972099
}
20982100
});
20992101

0 commit comments

Comments
 (0)