Skip to content

Commit 2ce655c

Browse files
authored
[mlir][func] Fix multiple bugs in DuplicateFunctionElimination (#109571)
This PR fixes multiple bugs in `DuplicateFunctionElimination`. - Prevents elimination of function declarations. - Updates all symbol uses to reference unique function representatives. Fixes #93483.
1 parent b8930cd commit 2ce655c

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
5454
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
5555
rhs == getTombstoneKey() || rhs == getEmptyKey())
5656
return false;
57+
58+
if (lhs.isDeclaration() || rhs.isDeclaration())
59+
return false;
60+
5761
// Check discardable attributes equivalence
5862
if (lhs->getDiscardableAttrDictionary() !=
5963
rhs->getDiscardableAttrDictionary())
@@ -97,14 +101,14 @@ struct DuplicateFunctionEliminationPass
97101
}
98102
});
99103

100-
// Update call ops to call unique func op representants.
101-
module.walk([&](func::CallOp callOp) {
102-
func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
103-
callOp.setCallee(callee.getSymName());
104-
});
105-
106-
// Erase redundant func ops.
104+
// Update all symbol uses to reference unique func op
105+
// representants and erase redundant func ops.
106+
SymbolTableCollection symbolTable;
107+
SymbolUserMap userMap(symbolTable, module);
107108
for (auto it : toBeErased) {
109+
StringAttr oldSymbol = it.getSymNameAttr();
110+
StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
111+
userMap.replaceAllUsesWith(it, newSymbol);
108112
it.erase();
109113
}
110114
}

mlir/test/Dialect/Func/duplicate-function-elimination.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,50 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
366366
// CHECK: @user
367367
// CHECK-2: call @deep_tree
368368
// CHECK: call @reverse_deep_tree
369+
370+
// -----
371+
372+
func.func private @func_declaration(i32, i32) -> i32
373+
func.func private @func_declaration1(i32, i32) -> i32
374+
375+
func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
376+
%0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
377+
%1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
378+
return %0, %1 : i32, i32
379+
}
380+
381+
// CHECK: @func_declaration
382+
// CHECK: @func_declaration1
383+
// CHECK: @user
384+
// CHECK: call @func_declaration
385+
// CHECK: call @func_declaration1
386+
387+
// -----
388+
389+
func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
390+
return %arg0 : tensor<f32>
391+
}
392+
393+
func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
394+
return %arg0 : tensor<f32>
395+
}
396+
397+
func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
398+
return %arg0 : tensor<f32>
399+
}
400+
401+
func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
402+
%f = constant @identity : (tensor<f32>) -> tensor<f32>
403+
%0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
404+
%f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
405+
%1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
406+
%2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
407+
return %2 : tensor<f32>
408+
}
409+
410+
// CHECK: @identity
411+
// CHECK-NOT: @also_identity
412+
// CHECK-NOT: @yet_another_identity
413+
// CHECK: @user
414+
// CHECK-2: constant @identity
415+
// CHECK: call @identity

0 commit comments

Comments
 (0)