Skip to content

Commit e6110cb

Browse files
[mlir][Transforms] Fix crash in -remove-dead-values on private functions (#169269)
This commit fixes two crashes in the `-remove-dead-values` pass related to private functions. Private functions are considered entirely "dead" by the liveness analysis, which drives the `-remove-dead-values` pass. The `-remove-dead-values` pass removes dead block arguments from private functions. Private functions are entirely dead, so all of their block arguments are removed. However, the pass did not correctly update all users of these dropped block arguments. 1. A side-effecting operation must be removed if one of its operands is dead. Otherwise, the operation would end up with a NULL operand. Note: The liveness analysis would not have marked an SSA value as "dead" if it had a reachable side-effecting users. (Therefore, it is safe to erase such side-effecting operations.) 2. A branch operation must be removed if one of its non-forwarded operands is dead. (E.g., the condition value of a `cf.cond_br`.) Whenever a terminator is removed, a `ub.unrechable` operation is inserted. This fixes #158760.
1 parent 30f479f commit e6110cb

File tree

4 files changed

+74
-2
lines changed

4 files changed

+74
-2
lines changed

mlir/include/mlir/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def RemoveDeadValues : Pass<"remove-dead-values"> {
248248
```
249249
}];
250250
let constructor = "mlir::createRemoveDeadValuesPass()";
251+
let dependentDialects = ["ub::UBDialect"];
251252
}
252253

253254
def PrintIRPass : Pass<"print-ir"> {

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ add_mlir_library(MLIRTransforms
3939
MLIRSideEffectInterfaces
4040
MLIRSupport
4141
MLIRTransformUtils
42+
MLIRUBDialect
4243
)

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
3535
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
36+
#include "mlir/Dialect/UB/IR/UBOps.h"
3637
#include "mlir/IR/Builders.h"
3738
#include "mlir/IR/BuiltinAttributes.h"
3839
#include "mlir/IR/Dialect.h"
@@ -260,6 +261,22 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
260261
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
261262
DenseSet<Value> &nonLiveSet,
262263
RDVFinalCleanupList &cl) {
264+
// Operations that have dead operands can be erased regardless of their
265+
// side effects. The liveness analysis would not have marked an SSA value as
266+
// "dead" if it had a side-effecting user that is reachable.
267+
bool hasDeadOperand =
268+
markLives(op->getOperands(), nonLiveSet, la).flip().any();
269+
if (hasDeadOperand) {
270+
LDBG() << "Simple op has dead operands, so the op must be dead: "
271+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
272+
assert(!hasLive(op->getResults(), nonLiveSet, la) &&
273+
"expected the op to have no live results");
274+
cl.operations.push_back(op);
275+
collectNonLiveValues(nonLiveSet, op->getResults(),
276+
BitVector(op->getNumResults(), true));
277+
return;
278+
}
279+
263280
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
264281
LDBG() << "Simple op is not memory effect free or has live results, "
265282
"preserving it: "
@@ -361,6 +378,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
361378
// block other than the entry block, because every block has a terminator.
362379
for (Block &block : funcOp.getBlocks()) {
363380
Operation *returnOp = block.getTerminator();
381+
if (!returnOp->hasTrait<OpTrait::ReturnLike>())
382+
continue;
364383
if (returnOp && returnOp->getNumOperands() == numReturns)
365384
cl.operands.push_back({returnOp, nonLiveRets});
366385
}
@@ -700,7 +719,11 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
700719
}
701720

702721
/// Steps to process a `BranchOpInterface` operation:
703-
/// Iterate through each successor block of `branchOp`.
722+
///
723+
/// When a non-forwarded operand is dead (e.g., the condition value of a
724+
/// conditional branch op), the entire operation is dead.
725+
///
726+
/// Otherwise, iterate through each successor block of `branchOp`.
704727
/// (1) For each successor block, gather all operands from all successors.
705728
/// (2) Fetch their associated liveness analysis data and collect for future
706729
/// removal.
@@ -711,7 +734,22 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
711734
DenseSet<Value> &nonLiveSet,
712735
RDVFinalCleanupList &cl) {
713736
LDBG() << "Processing branch op: " << *branchOp;
737+
738+
// Check for dead non-forwarded operands.
739+
BitVector deadNonForwardedOperands =
740+
markLives(branchOp->getOperands(), nonLiveSet, la).flip();
714741
unsigned numSuccessors = branchOp->getNumSuccessors();
742+
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
743+
SuccessorOperands successorOperands =
744+
branchOp.getSuccessorOperands(succIdx);
745+
// Remove all non-forwarded operands from the bit vector.
746+
for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands())
747+
deadNonForwardedOperands[opOperand.getOperandNumber()] = false;
748+
}
749+
if (deadNonForwardedOperands.any()) {
750+
cl.operations.push_back(branchOp.getOperation());
751+
return;
752+
}
715753

716754
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
717755
Block *successorBlock = branchOp->getSuccessor(succIdx);
@@ -786,9 +824,14 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
786824

787825
// 3. Operations
788826
LDBG() << "Cleaning up " << list.operations.size() << " operations";
789-
for (auto &op : list.operations) {
827+
for (Operation *op : list.operations) {
790828
LDBG() << "Erasing operation: "
791829
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
830+
if (op->hasTrait<OpTrait::IsTerminator>()) {
831+
// When erasing a terminator, insert an unreachable op in its place.
832+
OpBuilder b(op);
833+
ub::UnreachableOp::create(b, op->getLoc());
834+
}
792835
op->dropAllUses();
793836
op->erase();
794837
}

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,17 @@ func.func @main(%arg0 : i32) {
118118

119119
// -----
120120

121+
// CHECK-LABEL: func.func private @clean_func_op_remove_side_effecting_op() {
122+
// CHECK-NEXT: return
123+
// CHECK-NEXT: }
124+
func.func private @clean_func_op_remove_side_effecting_op(%arg0: i32) -> (i32) {
125+
// vector.print has a side effect but the op is dead.
126+
vector.print %arg0 : i32
127+
return %arg0 : i32
128+
}
129+
130+
// -----
131+
121132
// %arg0 is not live because it is never used. %arg1 is not live because its
122133
// user `arith.addi` doesn't have any uses and the value that it is forwarded to
123134
// (%non_live_0) also doesn't have any uses.
@@ -687,3 +698,19 @@ func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) {
687698
// CHECK-NEXT: return
688699
return
689700
}
701+
702+
// -----
703+
704+
// CHECK-LABEL: func private @remove_dead_branch_op()
705+
// CHECK-NEXT: ub.unreachable
706+
// CHECK-NEXT: ^{{.*}}:
707+
// CHECK-NEXT: return
708+
// CHECK-NEXT: ^{{.*}}:
709+
// CHECK-NEXT: return
710+
func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64) {
711+
cf.cond_br %c, ^bb1, ^bb2
712+
^bb1:
713+
return %arg0 : i64
714+
^bb2:
715+
return %arg1 : i64
716+
}

0 commit comments

Comments
 (0)