Skip to content

Commit 4904fb9

Browse files
committed
Add LIT test contains SCF
This diff also tries to propagate liveness recursively.
1 parent 3f4d69f commit 4904fb9

File tree

3 files changed

+99
-19
lines changed

3 files changed

+99
-19
lines changed

mlir/include/mlir/IR/Visitors.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ struct BackwardIterator {
4545
static auto makeIterable(T &range) {
4646
if constexpr (std::is_same<T, Operation>()) {
4747
/// Make operations iterable: return the list of regions.
48-
return llvm::reverse(range.getRegions());
48+
return range.getRegions();
4949
} else {
5050
/// Regions and block are already iterable.
5151
return llvm::reverse(range);

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -269,16 +269,6 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
269269
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
270270
DenseSet<Value> &nonLiveSet,
271271
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
272-
for (Value val : op->getResults()) {
273-
if (liveSet.contains(val)) {
274-
LDBG() << "Simple op is used by a public function, "
275-
"preserving it: "
276-
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
277-
liveSet.insert_range(op->getOperands());
278-
return;
279-
}
280-
}
281-
282272
if (!isMemoryEffectFree(op) ||
283273
hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
284274
LDBG() << "Simple op is not memory effect free or has live results, "
@@ -412,6 +402,82 @@ static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
412402
return {};
413403
}
414404

405+
// When you mark a call operand as live, also mark its definition chain, recursively.
406+
// We handle RegionBranchOpInterface here. I think we should handle BranchOpInterface as well.
407+
void propagateBackward(Value val, DenseSet<Value> &liveSet) {
408+
if (liveSet.contains(val)) return;
409+
liveSet.insert(val);
410+
411+
if (auto defOp = val.getDefiningOp()) {
412+
// Mark operands of live results as live
413+
for (Value operand : defOp->getOperands()) {
414+
propagateBackward(operand, liveSet);
415+
}
416+
417+
// Handle RegionBranchOpInterface specially
418+
if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(defOp)) {
419+
// If this is a result of a RegionBranchOpInterface, we need to trace back
420+
// through the control flow to find the sources that contribute to this result
421+
422+
OpResult result = cast<OpResult>(val);
423+
unsigned resultIndex = result.getResultNumber();
424+
425+
// Find all possible sources that can contribute to this result
426+
// by examining all regions and their terminators
427+
for (Region &region : regionBranchOp->getRegions()) {
428+
if (region.empty()) continue;
429+
430+
// Get the successors from this region
431+
SmallVector<RegionSuccessor> successors;
432+
regionBranchOp.getSuccessorRegions(RegionBranchPoint(&region), successors);
433+
434+
// Check if any successor can produce this result
435+
for (const RegionSuccessor &successor : successors) {
436+
if (successor.isParent()) {
437+
// This region can return to the parent operation
438+
ValueRange successorInputs = successor.getSuccessorInputs();
439+
if (resultIndex < successorInputs.size()) {
440+
// Find the terminator that contributes to this result
441+
Operation *terminator = region.back().getTerminator();
442+
if (auto regionBranchTerm =
443+
dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
444+
OperandRange terminatorOperands =
445+
regionBranchTerm.getSuccessorOperands(RegionBranchPoint::parent());
446+
if (resultIndex < terminatorOperands.size()) {
447+
// This terminator operand contributes to our result
448+
propagateBackward(terminatorOperands[resultIndex], liveSet);
449+
}
450+
}
451+
}
452+
}
453+
}
454+
455+
// Also mark region arguments as live if they might contribute to this result
456+
// Find which operand of the parent operation corresponds to region arguments
457+
Block &entryBlock = region.front();
458+
for (BlockArgument arg : entryBlock.getArguments()) {
459+
// Get entry successor operands - these are the operands that flow
460+
// from the parent operation to this region
461+
SmallVector<RegionSuccessor> entrySuccessors;
462+
regionBranchOp.getSuccessorRegions(RegionBranchPoint::parent(), entrySuccessors);
463+
464+
for (const RegionSuccessor &entrySuccessor : entrySuccessors) {
465+
if (entrySuccessor.getSuccessor() == &region) {
466+
// Get the operands that are forwarded to this region
467+
OperandRange entryOperands =
468+
regionBranchOp.getEntrySuccessorOperands(RegionBranchPoint::parent());
469+
unsigned argIndex = arg.getArgNumber();
470+
if (argIndex < entryOperands.size()) {
471+
propagateBackward(entryOperands[argIndex], liveSet);
472+
}
473+
break;
474+
}
475+
}
476+
}
477+
}
478+
}
479+
}
480+
}
415481
static void processCallOp(CallOpInterface callOp, Operation *module,
416482
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
417483
DenseSet<Value> &liveSet) {
@@ -439,13 +505,8 @@ static void processCallOp(CallOpInterface callOp, Operation *module,
439505

440506
for (int index : nonLiveArgs.set_bits()) {
441507
OpOperand *operand = callOpOperands[index];
442-
Value oldVal = operand->get();
443-
if (Value dummy = createDummyArgument(callOp, oldVal)) {
444-
callOp->setOperand(operand->getOperandNumber(), dummy);
445-
nonLiveSet.insert(oldVal);
446-
} else {
447-
liveSet.insert(oldVal);
448-
}
508+
LDBG() << "mark operand " << index << " live " << operand->get();
509+
propagateBackward(operand->get(), liveSet);
449510
}
450511
}
451512
}

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,8 +576,9 @@ module @return_void_with_unused_argument {
576576
}
577577

578578
// CHECK-LABEL: func.func @main2
579+
// CHECK: %[[ONE:.*]] = arith.constant 1 : i32
580+
// CHECK: %[[UNUSED:.*]] = arith.addi %[[ONE]], %[[ONE]] : i32
579581
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
580-
// CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
581582
// CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
582583
func.func @main2() -> () {
583584
%one = arith.constant 1 : i32
@@ -587,6 +588,24 @@ module @return_void_with_unused_argument {
587588
call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
588589
return
589590
}
591+
592+
// CHECK-LABEL: func.func @main3
593+
// CHECK: %[[UNUSED:.*]] = scf.if %arg0 -> (i32)
594+
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
595+
// CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
596+
func.func @main3(%arg0: i1) {
597+
%0 = scf.if %arg0 -> (i32) {
598+
%c1_i32 = arith.constant 1 : i32
599+
scf.yield %c1_i32 : i32
600+
} else {
601+
%c0_i32 = arith.constant 0 : i32
602+
scf.yield %c0_i32 : i32
603+
}
604+
%mem = memref.alloc() : memref<4xf32>
605+
606+
call @immutable_fn_with_unused_argument(%0, %mem) : (i32, memref<4xf32>) -> ()
607+
return
608+
}
590609
}
591610

592611
// -----

0 commit comments

Comments
 (0)