Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
69e1cc7
new runOn method
Casperento Sep 24, 2024
b6b8b67
Merge branch 'main' into merge-functions
Casperento Nov 1, 2024
e391813
Merge branch 'main' into merge-functions
Casperento Nov 2, 2024
3df9ca1
fix format and stable_hash
Casperento Nov 2, 2024
587f210
Merge branch 'main' into merge-functions
Casperento Nov 3, 2024
58be745
Merge branch 'main' into merge-functions
Casperento Nov 4, 2024
72b0431
Merge branch 'main' into merge-functions
Casperento Nov 5, 2024
e22fd6c
Merge branch 'main' into merge-functions
Casperento Nov 5, 2024
67d70af
fix comments
Casperento Nov 6, 2024
b15b1bf
Merge branch 'main' into merge-functions
Casperento Nov 6, 2024
c338f3d
Merge branch 'main' into merge-functions
Casperento Nov 13, 2024
eae748a
fix comments
Casperento Nov 15, 2024
1fdf94e
fix format
Casperento Nov 15, 2024
d193a63
fix comment
Casperento Nov 16, 2024
c2fa33b
Merge branch 'main' into merge-functions
Casperento Nov 16, 2024
1c91d48
Merge branch 'main' into merge-functions
Casperento Nov 17, 2024
913f418
Merge branch 'main' into merge-functions
Casperento Nov 22, 2024
04c23d0
unused includes removed
Casperento Nov 22, 2024
7daef3c
Merge branch 'main' into merge-functions
Casperento Nov 27, 2024
350395d
comments fix
Casperento Nov 27, 2024
1948974
Merge branch 'main' into merge-functions
Casperento Nov 27, 2024
07aa7c1
comment fix
Casperento Nov 27, 2024
fb694a8
Merge branch 'main' into merge-functions
Casperento Nov 27, 2024
a6f6b3c
comment fix
Casperento Nov 28, 2024
97a440d
Merge branch 'main' into merge-functions
Casperento Nov 28, 2024
8703263
comment fix
Casperento Nov 28, 2024
f758d86
Merge branch 'main' into merge-functions
Casperento Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/Transforms/IPO/MergeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
namespace llvm {

class Module;
class Function;

/// Merge identical functions.
class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);

static bool runOnModule(Module &M);
static DenseMap<Function *, Function *>
runOnFunctions(ArrayRef<Function *> F);
};

} // end namespace llvm
Expand Down
61 changes: 45 additions & 16 deletions llvm/lib/Transforms/IPO/MergeFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ class MergeFunctions {
MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
}

bool runOnModule(Module &M);
template <typename FuncContainer> bool run(FuncContainer &Functions);
DenseMap<Function *, Function *> runOnFunctions(ArrayRef<Function *> F);

SmallPtrSet<GlobalValue *, 4> &getUsed();

private:
// The function comparison operator is provided here so that FunctionNodes do
Expand Down Expand Up @@ -297,17 +300,36 @@ class MergeFunctions {
// dangling iterators into FnTree. The invariant that preserves this is that
// there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;

/// Deleted-New functions mapping
DenseMap<Function *, Function *> DelToNewMap;
};
} // end anonymous namespace

PreservedAnalyses MergeFunctionsPass::run(Module &M,
ModuleAnalysisManager &AM) {
MergeFunctions MF;
if (!MF.runOnModule(M))
if (!MergeFunctionsPass::runOnModule(M))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}

SmallPtrSet<GlobalValue *, 4> &MergeFunctions::getUsed() { return Used; }

bool MergeFunctionsPass::runOnModule(Module &M) {
MergeFunctions MF;
SmallVector<GlobalValue *, 4> UsedV;
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
MF.getUsed().insert(UsedV.begin(), UsedV.end());
return MF.run(M);
}

DenseMap<Function *, Function *>
MergeFunctionsPass::runOnFunctions(ArrayRef<Function *> F) {
MergeFunctions MF;
return MF.runOnFunctions(F);
}

#ifndef NDEBUG
bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
if (const unsigned Max = NumFunctionsForVerificationCheck) {
Expand Down Expand Up @@ -409,20 +431,19 @@ static bool isEligibleForMerging(Function &F) {
!hasDistinctMetadataIntrinsic(F);
}

bool MergeFunctions::runOnModule(Module &M) {
bool Changed = false;
inline Function *asPtr(Function *Fn) { return Fn; }
inline Function *asPtr(Function &Fn) { return &Fn; }

SmallVector<GlobalValue *, 4> UsedV;
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
Used.insert(UsedV.begin(), UsedV.end());
template <typename FuncContainer> bool MergeFunctions::run(FuncContainer &M) {
bool Changed = false;

// All functions in the module, ordered by hash. Functions with a unique
// hash value are easily eliminated.
std::vector<std::pair<stable_hash, Function *>> HashedFuncs;
for (Function &Func : M) {
if (isEligibleForMerging(Func)) {
HashedFuncs.push_back({StructuralHash(Func), &Func});
for (auto &Func : M) {
Function *FuncPtr = asPtr(Func);
if (isEligibleForMerging(*FuncPtr)) {
HashedFuncs.push_back({StructuralHash(*FuncPtr), FuncPtr});
}
}

Expand All @@ -433,7 +454,7 @@ bool MergeFunctions::runOnModule(Module &M) {
// If the hash value matches the previous value or the next one, we must
// consider merging it. Otherwise it is dropped and never considered again.
if ((I != S && std::prev(I)->first == I->first) ||
(std::next(I) != IE && std::next(I)->first == I->first) ) {
(std::next(I) != IE && std::next(I)->first == I->first)) {
Deferred.push_back(WeakTrackingVH(I->second));
}
}
Expand Down Expand Up @@ -467,9 +488,16 @@ bool MergeFunctions::runOnModule(Module &M) {
return Changed;
}

DenseMap<Function *, Function *>
MergeFunctions::runOnFunctions(ArrayRef<Function *> F) {
[[maybe_unused]] bool MergeResult = this->run(F);
assert(MergeResult == !DelToNewMap.empty());
return this->DelToNewMap;
}

// Replace direct callers of Old with New.
void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
for (Use &U : llvm::make_early_inc_range(Old->uses())) {
for (Use &U : make_early_inc_range(Old->uses())) {
CallBase *CB = dyn_cast<CallBase>(U.getUser());
if (CB && CB->isCallee(&U)) {
// Do not copy attributes from the called function to the call-site.
Expand Down Expand Up @@ -768,8 +796,8 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
ReturnInst *RI = nullptr;
bool isSwiftTailCall = F->getCallingConv() == CallingConv::SwiftTail &&
G->getCallingConv() == CallingConv::SwiftTail;
CI->setTailCallKind(isSwiftTailCall ? llvm::CallInst::TCK_MustTail
: llvm::CallInst::TCK_Tail);
CI->setTailCallKind(isSwiftTailCall ? CallInst::TCK_MustTail
: CallInst::TCK_Tail);
CI->setCallingConv(F->getCallingConv());
CI->setAttributes(F->getAttributes());
if (H->getReturnType()->isVoidTy()) {
Expand Down Expand Up @@ -1003,6 +1031,7 @@ bool MergeFunctions::insert(Function *NewFunction) {

Function *DeleteF = NewFunction;
mergeTwoFunctions(OldF.getFunc(), DeleteF);
this->DelToNewMap.insert({DeleteF, OldF.getFunc()});
return true;
}

Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests
LoopUtilsTest.cpp
MemTransferLowering.cpp
ModuleUtilsTest.cpp
MergeFunctionsTest.cpp
ScalarEvolutionExpanderTest.cpp
SizeOptsTest.cpp
SSAUpdaterBulkTest.cpp
Expand Down
246 changes: 246 additions & 0 deletions llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
//===- MergeFunctionsTest.cpp - Unit tests for MergeFunctionsPass ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/MergeFunctions.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
#include <memory>

using namespace llvm;

namespace {

TEST(MergeFunctions, TrueOutputModuleTest) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1

define dso_local i32 @f(i32 noundef %arg) {
entry:
%add109 = call i32 @_slice_add10(i32 %arg)
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
ret i32 %add109
}

declare i32 @printf(ptr noundef, ...)

define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
entry:
%add99 = call i32 @_slice_add10(i32 %argc)
%call = call i32 @f(i32 noundef 2)
%sub = sub nsw i32 %call, 6
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
ret i32 %add99
}

define internal i32 @_slice_add10(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %4
}

define internal i32 @_slice_add10_alt(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %4
}
)invalid",
Err, Ctx));

// Expects true after merging _slice_add10 and _slice_add10_alt
EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
}

TEST(MergeFunctions, TrueOutputFunctionsTest) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1

define dso_local i32 @f(i32 noundef %arg) {
entry:
%add109 = call i32 @_slice_add10(i32 %arg)
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
ret i32 %add109
}

declare i32 @printf(ptr noundef, ...)

define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
entry:
%add99 = call i32 @_slice_add10(i32 %argc)
%call = call i32 @f(i32 noundef 2)
%sub = sub nsw i32 %call, 6
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
ret i32 %add99
}

define internal i32 @_slice_add10(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %4
}

define internal i32 @_slice_add10_alt(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %4
}
)invalid",
Err, Ctx));

SetVector<Function *> FunctionsSet;
for (Function &F : *M)
FunctionsSet.insert(&F);

DenseMap<Function *, Function *> MergeResult =
MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());

// Expects that both functions (_slice_add10 and _slice_add10_alt)
// be mapped to the same new function
EXPECT_TRUE(!MergeResult.empty());
Function *NewFunction = M->getFunction("_slice_add10");
for (auto P : MergeResult)
if (P.second)
EXPECT_EQ(P.second, NewFunction);
}

TEST(MergeFunctions, FalseOutputModuleTest) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1

define dso_local i32 @f(i32 noundef %arg) {
entry:
%add109 = call i32 @_slice_add10(i32 %arg)
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
ret i32 %add109
}

declare i32 @printf(ptr noundef, ...)

define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
entry:
%add99 = call i32 @_slice_add10(i32 %argc)
%call = call i32 @f(i32 noundef 2)
%sub = sub nsw i32 %call, 6
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
ret i32 %add99
}

define internal i32 @_slice_add10(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %4
}

define internal i32 @_slice_add10_alt(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %0
}
)invalid",
Err, Ctx));

// Expects false after trying to merge _slice_add10 and _slice_add10_alt
EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
}

TEST(MergeFunctions, FalseOutputFunctionsTest) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1

define dso_local i32 @f(i32 noundef %arg) {
entry:
%add109 = call i32 @_slice_add10(i32 %arg)
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
ret i32 %add109
}

declare i32 @printf(ptr noundef, ...)

define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
entry:
%add99 = call i32 @_slice_add10(i32 %argc)
%call = call i32 @f(i32 noundef 2)
%sub = sub nsw i32 %call, 6
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
ret i32 %add99
}

define internal i32 @_slice_add10(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %4
}

define internal i32 @_slice_add10_alt(i32 %arg) {
sliceclone_entry:
%0 = mul nsw i32 %arg, %arg
%1 = mul nsw i32 %0, 2
%2 = mul nsw i32 %1, 2
%3 = mul nsw i32 %2, 2
%4 = add nsw i32 %3, 2
ret i32 %0
}
)invalid",
Err, Ctx));

SetVector<Function *> FunctionsSet;
for (Function &F : *M)
FunctionsSet.insert(&F);

DenseMap<Function *, Function *> MergeResult =
MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());

// Expects empty map
EXPECT_EQ(MergeResult.size(), 0u);
}

} // namespace
Loading