Skip to content

Commit 2a6e589

Browse files
authored
[MergeFunctions] Add support to run the pass over a set of function pointers (#111045)
This change makes `MergeFunctions` usable as a standalone library. Currently, `MergeFunctions` can only be applied to an entire module. With this change, developers can reuse the `MergeFunctions` code in their own projects and choose which functions to merge, which promotes code reuse. Note that this change does not break backward compatibility: `MergeFunctions` still works as a pass.
1 parent 3a01b46 commit 2a6e589

File tree

5 files changed

+298
-16
lines changed

5 files changed

+298
-16
lines changed

llvm/include/llvm/Transforms/IPO/MergeFunctions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,16 @@
2020
namespace llvm {
2121

2222
class Module;
23+
class Function;
2324

2425
/// Merge identical functions.
2526
class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
2627
public:
2728
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
29+
30+
static bool runOnModule(Module &M);
31+
static DenseMap<Function *, Function *>
32+
runOnFunctions(ArrayRef<Function *> F);
2833
};
2934

3035
} // end namespace llvm

llvm/lib/Transforms/IPO/MergeFunctions.cpp

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,10 @@ class MergeFunctions {
196196
MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
197197
}
198198

199-
bool runOnModule(Module &M);
199+
template <typename FuncContainer> bool run(FuncContainer &Functions);
200+
DenseMap<Function *, Function *> runOnFunctions(ArrayRef<Function *> F);
201+
202+
SmallPtrSet<GlobalValue *, 4> &getUsed();
200203

201204
private:
202205
// The function comparison operator is provided here so that FunctionNodes do
@@ -297,17 +300,36 @@ class MergeFunctions {
297300
// dangling iterators into FnTree. The invariant that preserves this is that
298301
// there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
299302
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
303+
304+
/// Maps each merged-away function to the function it was merged into.
305+
DenseMap<Function *, Function *> DelToNewMap;
300306
};
301307
} // end anonymous namespace
302308

303309
PreservedAnalyses MergeFunctionsPass::run(Module &M,
304310
ModuleAnalysisManager &AM) {
305-
MergeFunctions MF;
306-
if (!MF.runOnModule(M))
311+
if (!MergeFunctionsPass::runOnModule(M))
307312
return PreservedAnalyses::all();
308313
return PreservedAnalyses::none();
309314
}
310315

316+
SmallPtrSet<GlobalValue *, 4> &MergeFunctions::getUsed() { return Used; }
317+
318+
bool MergeFunctionsPass::runOnModule(Module &M) {
319+
MergeFunctions MF;
320+
SmallVector<GlobalValue *, 4> UsedV;
321+
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
322+
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
323+
MF.getUsed().insert(UsedV.begin(), UsedV.end());
324+
return MF.run(M);
325+
}
326+
327+
DenseMap<Function *, Function *>
328+
MergeFunctionsPass::runOnFunctions(ArrayRef<Function *> F) {
329+
MergeFunctions MF;
330+
return MF.runOnFunctions(F);
331+
}
332+
311333
#ifndef NDEBUG
312334
bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
313335
if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -409,20 +431,19 @@ static bool isEligibleForMerging(Function &F) {
409431
!hasDistinctMetadataIntrinsic(F);
410432
}
411433

412-
bool MergeFunctions::runOnModule(Module &M) {
413-
bool Changed = false;
434+
inline Function *asPtr(Function *Fn) { return Fn; }
435+
inline Function *asPtr(Function &Fn) { return &Fn; }
414436

415-
SmallVector<GlobalValue *, 4> UsedV;
416-
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
417-
collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
418-
Used.insert(UsedV.begin(), UsedV.end());
437+
template <typename FuncContainer> bool MergeFunctions::run(FuncContainer &M) {
438+
bool Changed = false;
419439

420440
// All eligible input functions, ordered by hash. Functions with a unique
421441
// hash value are easily eliminated.
422442
std::vector<std::pair<stable_hash, Function *>> HashedFuncs;
423-
for (Function &Func : M) {
424-
if (isEligibleForMerging(Func)) {
425-
HashedFuncs.push_back({StructuralHash(Func), &Func});
443+
for (auto &Func : M) {
444+
Function *FuncPtr = asPtr(Func);
445+
if (isEligibleForMerging(*FuncPtr)) {
446+
HashedFuncs.push_back({StructuralHash(*FuncPtr), FuncPtr});
426447
}
427448
}
428449

@@ -433,7 +454,7 @@ bool MergeFunctions::runOnModule(Module &M) {
433454
// If the hash value matches the previous value or the next one, we must
434455
// consider merging it. Otherwise it is dropped and never considered again.
435456
if ((I != S && std::prev(I)->first == I->first) ||
436-
(std::next(I) != IE && std::next(I)->first == I->first) ) {
457+
(std::next(I) != IE && std::next(I)->first == I->first)) {
437458
Deferred.push_back(WeakTrackingVH(I->second));
438459
}
439460
}
@@ -467,9 +488,16 @@ bool MergeFunctions::runOnModule(Module &M) {
467488
return Changed;
468489
}
469490

491+
DenseMap<Function *, Function *>
492+
MergeFunctions::runOnFunctions(ArrayRef<Function *> F) {
493+
[[maybe_unused]] bool MergeResult = this->run(F);
494+
assert(MergeResult == !DelToNewMap.empty());
495+
return this->DelToNewMap;
496+
}
497+
470498
// Replace direct callers of Old with New.
471499
void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
472-
for (Use &U : llvm::make_early_inc_range(Old->uses())) {
500+
for (Use &U : make_early_inc_range(Old->uses())) {
473501
CallBase *CB = dyn_cast<CallBase>(U.getUser());
474502
if (CB && CB->isCallee(&U)) {
475503
// Do not copy attributes from the called function to the call-site.
@@ -768,8 +796,8 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
768796
ReturnInst *RI = nullptr;
769797
bool isSwiftTailCall = F->getCallingConv() == CallingConv::SwiftTail &&
770798
G->getCallingConv() == CallingConv::SwiftTail;
771-
CI->setTailCallKind(isSwiftTailCall ? llvm::CallInst::TCK_MustTail
772-
: llvm::CallInst::TCK_Tail);
799+
CI->setTailCallKind(isSwiftTailCall ? CallInst::TCK_MustTail
800+
: CallInst::TCK_Tail);
773801
CI->setCallingConv(F->getCallingConv());
774802
CI->setAttributes(F->getAttributes());
775803
if (H->getReturnType()->isVoidTy()) {
@@ -1003,6 +1031,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
10031031

10041032
Function *DeleteF = NewFunction;
10051033
mergeTwoFunctions(OldF.getFunc(), DeleteF);
1034+
this->DelToNewMap.insert({DeleteF, OldF.getFunc()});
10061035
return true;
10071036
}
10081037

llvm/unittests/Transforms/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests
2626
LoopUtilsTest.cpp
2727
MemTransferLowering.cpp
2828
ModuleUtilsTest.cpp
29+
MergeFunctionsTest.cpp
2930
ScalarEvolutionExpanderTest.cpp
3031
SizeOptsTest.cpp
3132
SSAUpdaterBulkTest.cpp
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
//===- MergeFunctionsTest.cpp - Unit tests for MergeFunctionsPass ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/IPO/MergeFunctions.h"
10+
11+
#include "llvm/ADT/SetVector.h"
12+
#include "llvm/AsmParser/Parser.h"
13+
#include "llvm/IR/LLVMContext.h"
14+
#include "llvm/IR/Module.h"
15+
#include "llvm/Support/SourceMgr.h"
16+
#include "gtest/gtest.h"
17+
#include <memory>
18+
19+
using namespace llvm;
20+
21+
namespace {
22+
23+
TEST(MergeFunctions, TrueOutputModuleTest) {
24+
LLVMContext Ctx;
25+
SMDiagnostic Err;
26+
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
27+
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
28+
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
29+
30+
define dso_local i32 @f(i32 noundef %arg) {
31+
entry:
32+
%add109 = call i32 @_slice_add10(i32 %arg)
33+
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
34+
ret i32 %add109
35+
}
36+
37+
declare i32 @printf(ptr noundef, ...)
38+
39+
define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
40+
entry:
41+
%add99 = call i32 @_slice_add10(i32 %argc)
42+
%call = call i32 @f(i32 noundef 2)
43+
%sub = sub nsw i32 %call, 6
44+
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
45+
ret i32 %add99
46+
}
47+
48+
define internal i32 @_slice_add10(i32 %arg) {
49+
sliceclone_entry:
50+
%0 = mul nsw i32 %arg, %arg
51+
%1 = mul nsw i32 %0, 2
52+
%2 = mul nsw i32 %1, 2
53+
%3 = mul nsw i32 %2, 2
54+
%4 = add nsw i32 %3, 2
55+
ret i32 %4
56+
}
57+
58+
define internal i32 @_slice_add10_alt(i32 %arg) {
59+
sliceclone_entry:
60+
%0 = mul nsw i32 %arg, %arg
61+
%1 = mul nsw i32 %0, 2
62+
%2 = mul nsw i32 %1, 2
63+
%3 = mul nsw i32 %2, 2
64+
%4 = add nsw i32 %3, 2
65+
ret i32 %4
66+
}
67+
)invalid",
68+
Err, Ctx));
69+
70+
// Expects true after merging _slice_add10 and _slice_add10_alt
71+
EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
72+
}
73+
74+
TEST(MergeFunctions, TrueOutputFunctionsTest) {
75+
LLVMContext Ctx;
76+
SMDiagnostic Err;
77+
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
78+
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
79+
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
80+
81+
define dso_local i32 @f(i32 noundef %arg) {
82+
entry:
83+
%add109 = call i32 @_slice_add10(i32 %arg)
84+
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
85+
ret i32 %add109
86+
}
87+
88+
declare i32 @printf(ptr noundef, ...)
89+
90+
define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
91+
entry:
92+
%add99 = call i32 @_slice_add10(i32 %argc)
93+
%call = call i32 @f(i32 noundef 2)
94+
%sub = sub nsw i32 %call, 6
95+
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
96+
ret i32 %add99
97+
}
98+
99+
define internal i32 @_slice_add10(i32 %arg) {
100+
sliceclone_entry:
101+
%0 = mul nsw i32 %arg, %arg
102+
%1 = mul nsw i32 %0, 2
103+
%2 = mul nsw i32 %1, 2
104+
%3 = mul nsw i32 %2, 2
105+
%4 = add nsw i32 %3, 2
106+
ret i32 %4
107+
}
108+
109+
define internal i32 @_slice_add10_alt(i32 %arg) {
110+
sliceclone_entry:
111+
%0 = mul nsw i32 %arg, %arg
112+
%1 = mul nsw i32 %0, 2
113+
%2 = mul nsw i32 %1, 2
114+
%3 = mul nsw i32 %2, 2
115+
%4 = add nsw i32 %3, 2
116+
ret i32 %4
117+
}
118+
)invalid",
119+
Err, Ctx));
120+
121+
SetVector<Function *> FunctionsSet;
122+
for (Function &F : *M)
123+
FunctionsSet.insert(&F);
124+
125+
DenseMap<Function *, Function *> MergeResult =
126+
MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());
127+
128+
// Expects a non-empty result in which every merged-away function
129+
// is mapped to _slice_add10 (the function that is kept)
130+
EXPECT_TRUE(!MergeResult.empty());
131+
Function *NewFunction = M->getFunction("_slice_add10");
132+
for (auto P : MergeResult)
133+
if (P.second)
134+
EXPECT_EQ(P.second, NewFunction);
135+
}
136+
137+
TEST(MergeFunctions, FalseOutputModuleTest) {
138+
LLVMContext Ctx;
139+
SMDiagnostic Err;
140+
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
141+
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
142+
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
143+
144+
define dso_local i32 @f(i32 noundef %arg) {
145+
entry:
146+
%add109 = call i32 @_slice_add10(i32 %arg)
147+
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
148+
ret i32 %add109
149+
}
150+
151+
declare i32 @printf(ptr noundef, ...)
152+
153+
define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
154+
entry:
155+
%add99 = call i32 @_slice_add10(i32 %argc)
156+
%call = call i32 @f(i32 noundef 2)
157+
%sub = sub nsw i32 %call, 6
158+
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
159+
ret i32 %add99
160+
}
161+
162+
define internal i32 @_slice_add10(i32 %arg) {
163+
sliceclone_entry:
164+
%0 = mul nsw i32 %arg, %arg
165+
%1 = mul nsw i32 %0, 2
166+
%2 = mul nsw i32 %1, 2
167+
%3 = mul nsw i32 %2, 2
168+
%4 = add nsw i32 %3, 2
169+
ret i32 %4
170+
}
171+
172+
define internal i32 @_slice_add10_alt(i32 %arg) {
173+
sliceclone_entry:
174+
%0 = mul nsw i32 %arg, %arg
175+
%1 = mul nsw i32 %0, 2
176+
%2 = mul nsw i32 %1, 2
177+
%3 = mul nsw i32 %2, 2
178+
%4 = add nsw i32 %3, 2
179+
ret i32 %0
180+
}
181+
)invalid",
182+
Err, Ctx));
183+
184+
// Expects false: _slice_add10_alt returns %0, not %4, so nothing is merged
185+
EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
186+
}
187+
188+
TEST(MergeFunctions, FalseOutputFunctionsTest) {
189+
LLVMContext Ctx;
190+
SMDiagnostic Err;
191+
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
192+
@.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
193+
@.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
194+
195+
define dso_local i32 @f(i32 noundef %arg) {
196+
entry:
197+
%add109 = call i32 @_slice_add10(i32 %arg)
198+
%call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
199+
ret i32 %add109
200+
}
201+
202+
declare i32 @printf(ptr noundef, ...)
203+
204+
define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
205+
entry:
206+
%add99 = call i32 @_slice_add10(i32 %argc)
207+
%call = call i32 @f(i32 noundef 2)
208+
%sub = sub nsw i32 %call, 6
209+
%call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
210+
ret i32 %add99
211+
}
212+
213+
define internal i32 @_slice_add10(i32 %arg) {
214+
sliceclone_entry:
215+
%0 = mul nsw i32 %arg, %arg
216+
%1 = mul nsw i32 %0, 2
217+
%2 = mul nsw i32 %1, 2
218+
%3 = mul nsw i32 %2, 2
219+
%4 = add nsw i32 %3, 2
220+
ret i32 %4
221+
}
222+
223+
define internal i32 @_slice_add10_alt(i32 %arg) {
224+
sliceclone_entry:
225+
%0 = mul nsw i32 %arg, %arg
226+
%1 = mul nsw i32 %0, 2
227+
%2 = mul nsw i32 %1, 2
228+
%3 = mul nsw i32 %2, 2
229+
%4 = add nsw i32 %3, 2
230+
ret i32 %0
231+
}
232+
)invalid",
233+
Err, Ctx));
234+
235+
SetVector<Function *> FunctionsSet;
236+
for (Function &F : *M)
237+
FunctionsSet.insert(&F);
238+
239+
DenseMap<Function *, Function *> MergeResult =
240+
MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());
241+
242+
// Expects an empty map: _slice_add10_alt returns %0, not %4, so nothing merges
243+
EXPECT_EQ(MergeResult.size(), 0u);
244+
}
245+
246+
} // namespace

0 commit comments

Comments
 (0)