diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp index d41d6c3e8972f..fbb6abfd65b10 100644 --- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp @@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo if (lhs == getTombstoneKey() || lhs == getEmptyKey() || rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; + + if (lhs.isDeclaration() || rhs.isDeclaration()) + return false; + // Check discardable attributes equivalence if (lhs->getDiscardableAttrDictionary() != rhs->getDiscardableAttrDictionary()) @@ -97,14 +101,14 @@ struct DuplicateFunctionEliminationPass } }); - // Update call ops to call unique func op representants. - module.walk([&](func::CallOp callOp) { - func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()]; - callOp.setCallee(callee.getSymName()); - }); - - // Erase redundant func ops. + // Update all symbol uses to reference unique func op + // representants and erase redundant func ops. + SymbolTableCollection symbolTable; + SymbolUserMap userMap(symbolTable, module); for (auto it : toBeErased) { + StringAttr oldSymbol = it.getSymNameAttr(); + StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr(); + userMap.replaceAllUsesWith(it, newSymbol); it.erase(); } } diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir index 28d059a149bde..1f8da78b6d63d 100644 --- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir +++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir @@ -366,3 +366,50 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32) // CHECK: @user // CHECK-2: call @deep_tree // CHECK: call @reverse_deep_tree + +// ----- + +func.func private @func_declaration(i32, i32) -> i32 +func.func private @func_declaration1(i32, i32) -> i32 + +func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) { + %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32 + %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32 + return %0, %1 : i32, i32 +} + +// CHECK: @func_declaration +// CHECK: @func_declaration1 +// CHECK: @user +// CHECK: call @func_declaration +// CHECK: call @func_declaration1 + +// ----- + +func.func @identity(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + +func.func @also_identity(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + +func.func @yet_another_identity(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + +func.func @user(%arg0: tensor) -> tensor { + %f = constant @identity : (tensor) -> tensor + %0 = call_indirect %f(%arg0) : (tensor) -> tensor + %f_0 = constant @also_identity : (tensor) -> tensor + %1 = call_indirect %f_0(%0) : (tensor) -> tensor + %2 = call @yet_another_identity(%1) : (tensor) -> tensor + return %2 : tensor +} + +// CHECK: @identity +// CHECK-NOT: @also_identity +// CHECK-NOT: @yet_another_identity +// CHECK: @user +// CHECK-2: constant @identity +// CHECK: call @identity