Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,30 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
op->erase();
}

// Remove the dead functions from moduleOp.
static void deleteDeadFunction(Operation *module) {
auto functions = module->getRegion(0).getOps<FunctionOpInterface>();
llvm::DenseSet<FunctionOpInterface> tasks(functions.begin(), functions.end());
while (!tasks.empty()) {
llvm::DenseSet<FunctionOpInterface> nextTasks;
for (FunctionOpInterface funcOp : tasks) {
if (funcOp.isPublic() || funcOp.isExternal())
return;
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this super expensive to build?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It certainly feels super expensive.I've updated the code so that it now only runs on functions that are likely to be dead functions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I accidentally discovered that the -symbol-dce pass includes the functionality of this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was mentioned to you earlier actually: #161471 (review)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must admit I wasn't particularly well-versed in many of the MLIR passes, but I've gained a much better understanding now. Ha ha.😘
Actually, I just took another look at the code for this pass. The most immediate issue stems from directly deleting the operands of cond_br. The underlying problem remains the data flow analysis issue mentioned earlier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit of a pity that -symbol-dce isn't implemented via patterns; otherwise, it might have been possible to incorporate it into remove-dead-values.To be perfectly honest, I know what the best fix is.Indeed, deleting dead functions is not the best approach.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think perhaps we could wrap the logic of symbol-dce in a function and then use it in remove-dead-values.

auto callSites = funcOp.getFunctionBody().getOps<CallOpInterface>();
if (uses.empty() && !callSites.empty()) {
for (CallOpInterface callOp : callSites) {
nextTasks.insert(cast<FunctionOpInterface>(callOp.resolveCallable()));
}
}

if (uses.empty() && !nextTasks.contains(funcOp))
funcOp.erase();
}
tasks = nextTasks;
}
}

/// Convert a list of `Operand`s to a list of `OpOperand`s.
static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
OpOperand *values = operands.getBase();
Expand Down Expand Up @@ -881,6 +905,8 @@ void RemoveDeadValues::runOnOperation() {
// end of this pass.
RDVFinalCleanupList finalCleanupList;

// Remove the dead function in advance.
deleteDeadFunction(module);
module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
Expand Down
48 changes: 47 additions & 1 deletion mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ module @llvm_unreachable {
func.func private @fn_with_llvm_unreachable(%arg0: tensor<4x4xf32>) -> tensor<4x4xi1> {
llvm.unreachable
}
func.func private @main(%arg0: tensor<4x4xf32>) {
func.func @main(%arg0: tensor<4x4xf32>) {
%0 = call @fn_with_llvm_unreachable(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xi1>
llvm.return
}
Expand Down Expand Up @@ -649,3 +649,49 @@ func.func @callee(%arg0: index, %arg1: index, %arg2: index) -> index {
%res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index)
return %res : index
}

// -----

// Test the elimination of dead functions.

// CHECK-NOT: func private @single_private_func
func.func private @single_private_func(%arg0: i64) -> (i64) {
%c0_i64 = arith.constant 0 : i64
%2 = arith.cmpi eq, %arg0, %c0_i64 : i64
cf.cond_br %2, ^bb1, ^bb2
^bb1: // pred: ^bb0
%c1_i64 = arith.constant 1 : i64
return %c1_i64 : i64
^bb2: // pred: ^bb0
%c3_i64 = arith.constant 3 : i64
return %c3_i64 : i64
}

// -----

// Test the elimination of dead functions.

// CHECK-NOT: @single_parameter
func.func private @single_parameter(%arg0: index) {
return
}

// CHECK-NOT: @mutl_parameter
func.func private @mutl_parameter(%arg0: index, %arg1: index, %arg2: index) -> index {
return %arg1 : index
}

// CHECK-NOT: @eliminate_parameter
func.func private @eliminate_parameter(%arg0: index, %arg1: index) {
call @single_parameter(%arg0) : (index) -> ()
return
}

// CHECK-NOT: @callee
func.func private @callee(%arg0: index, %arg1: index, %arg2: index) -> index {
// CHECK-NOT: call @eliminate_parameter
call @eliminate_parameter(%arg0, %arg1) : (index, index) -> ()
// CHECK-NOT: call @mutl_parameter
%res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index)
return %res : index
}