Skip to content

Commit 4791659

Browse files
committed
[Attributor] Convert out arguments into a struct return
1 parent 6b545c2 commit 4791659

File tree

4 files changed

+294
-5
lines changed

4 files changed

+294
-5
lines changed

llvm/include/llvm/Transforms/IPO/Attributor.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6477,6 +6477,39 @@ struct AADenormalFPMath
64776477
static const char ID;
64786478
};
64796479

6480+
/// An abstract attribute for converting out arguments into struct elements.
6481+
struct AAConvertOutArgument
6482+
: public StateWrapper<BooleanState, AbstractAttribute> {
6483+
using Base = StateWrapper<BooleanState, AbstractAttribute>;
6484+
6485+
AAConvertOutArgument(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
6486+
6487+
/// Create an abstract attribute view for the position \p IRP.
6488+
static AAConvertOutArgument &createForPosition(const IRPosition &IRP,
6489+
Attributor &A);
6490+
6491+
/// See AbstractAttribute::getName()
6492+
const std::string getName() const override { return "AAConvertOutArgument"; }
6493+
6494+
/// Return true if convertible is assumed.
6495+
bool isAssumedConvertible() const { return getAssumed(); }
6496+
6497+
/// Return true if convertible is known.
6498+
bool isKnownConvertible() const { return getKnown(); }
6499+
6500+
/// See AbstractAttribute::getIdAddr()
6501+
const char *getIdAddr() const override { return &ID; }
6502+
6503+
/// This function should return true if the type of the \p AA is
6504+
/// AADenormalFPMath.
6505+
static bool classof(const AbstractAttribute *AA) {
6506+
return (AA->getIdAddr() == &ID);
6507+
}
6508+
6509+
/// Unique ID (due to the unique address)
6510+
static const char ID;
6511+
};
6512+
64806513
raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &);
64816514

64826515
/// Run options, used by the pass manager.

llvm/lib/Transforms/IPO/Attributor.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3459,6 +3459,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
34593459
}
34603460
}
34613461

3462+
bool markedAsAAConvertArgument = false;
34623463
for (Argument &Arg : F.args()) {
34633464
IRPosition ArgPos = IRPosition::argument(Arg);
34643465
auto ArgNo = Arg.getArgNo();
@@ -3510,6 +3511,12 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
35103511
// Every argument with pointer type might be privatizable (or
35113512
// promotable)
35123513
getOrCreateAAFor<AAPrivatizablePtr>(ArgPos);
3514+
3515+
// Every function with pointer argument type can have out arguments.
3516+
if (!markedAsAAConvertArgument) {
3517+
getOrCreateAAFor<AAConvertOutArgument>(FPos);
3518+
markedAsAAConvertArgument = true;
3519+
}
35133520
} else if (AttributeFuncs::isNoFPClassCompatibleType(Arg.getType())) {
35143521
getOrCreateAAFor<AANoFPClass>(ArgPos);
35153522
}

llvm/lib/Transforms/IPO/AttributorAttributes.cpp

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
#include "llvm/Support/raw_ostream.h"
6969
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
7070
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
71+
#include "llvm/Transforms/Utils/Cloning.h"
7172
#include "llvm/Transforms/Utils/Local.h"
7273
#include "llvm/Transforms/Utils/ValueMapper.h"
7374
#include <cassert>
@@ -197,6 +198,7 @@ PIPE_OPERATOR(AAAllocationInfo)
197198
PIPE_OPERATOR(AAIndirectCallInfo)
198199
PIPE_OPERATOR(AAGlobalValueInfo)
199200
PIPE_OPERATOR(AADenormalFPMath)
201+
PIPE_OPERATOR(AAConvertOutArgument)
200202

201203
#undef PIPE_OPERATOR
202204

@@ -12987,6 +12989,225 @@ struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl {
1298712989
};
1298812990
} // namespace
1298912991

12992+
/// ----------- AAConvertOutArgument ----------
12993+
namespace {
12994+
static bool isEligibleArgument(const Argument &Arg, Attributor &A,
12995+
const AbstractAttribute &AA) {
12996+
if (!Arg.getType()->isPointerTy())
12997+
return false;
12998+
12999+
const IRPosition &ArgPos = IRPosition::argument(Arg);
13000+
auto *AAMem = A.getAAFor<AAMemoryBehavior>(AA, ArgPos, DepClassTy::OPTIONAL);
13001+
auto *NoAlias = A.getAAFor<AANoAlias>(AA, ArgPos, DepClassTy::OPTIONAL);
13002+
13003+
return AAMem && NoAlias && AAMem->isAssumedWriteOnly() &&
13004+
NoAlias->isAssumedNoAlias() && !Arg.hasPointeeInMemoryValueAttr();
13005+
}
13006+
13007+
struct AAConvertOutArgumentFunction final : AAConvertOutArgument {
13008+
AAConvertOutArgumentFunction(const IRPosition &IRP, Attributor &A)
13009+
: AAConvertOutArgument(IRP, A) {}
13010+
13011+
SmallVector<bool> ArgumentsStates;
13012+
13013+
/// See AbstractAttribute::updateImpl(...).
13014+
void initialize(Attributor &A) override {
13015+
const Function *F = getAssociatedFunction();
13016+
if (!F || F->isDeclaration())
13017+
return;
13018+
13019+
// Assume that all args are convertable at the begining.
13020+
ArgumentsStates.resize(F->arg_size(), true);
13021+
}
13022+
13023+
/// See AbstractAttribute::updateImpl(...).
13024+
ChangeStatus updateImpl(Attributor &A) override {
13025+
const Function *F = getAssociatedFunction();
13026+
if (!F || F->isDeclaration())
13027+
return indicatePessimisticFixpoint();
13028+
13029+
auto NewStates = ArgumentsStates;
13030+
for (unsigned ArgIdx = 0; ArgIdx < F->arg_size(); ++ArgIdx)
13031+
if (!isEligibleArgument(*F->getArg(ArgIdx), A, *this))
13032+
NewStates[ArgIdx] = false;
13033+
13034+
bool Changed = NewStates == ArgumentsStates;
13035+
ArgumentsStates = NewStates;
13036+
return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
13037+
}
13038+
13039+
/// See AbstractAttribute::manifest(...).
13040+
ChangeStatus manifest(Attributor &A) override {
13041+
Function &F = *getAssociatedFunction();
13042+
DenseMap<Argument *, Type *> PtrToType;
13043+
SmallVector<Argument *, 4> CandidateArgs;
13044+
13045+
for (unsigned ArgIdx = 0; ArgIdx < F.arg_size(); ++ArgIdx) {
13046+
Argument *Arg = F.getArg(ArgIdx);
13047+
if (!isEligibleArgument(*Arg, A, *this))
13048+
continue;
13049+
13050+
CandidateArgs.push_back(Arg);
13051+
// AAPointerInfo on args
13052+
for (auto &Use : Arg->uses())
13053+
if (auto *Store = dyn_cast<StoreInst>(Use.getUser()))
13054+
PtrToType[Arg] = Store->getValueOperand()->getType();
13055+
}
13056+
13057+
// If there is no valid candidates then return false.
13058+
if (PtrToType.empty())
13059+
return indicatePessimisticFixpoint();
13060+
13061+
// Create the new struct return type.
13062+
SmallVector<Type *, 4> OutStructElementsTypes;
13063+
if (auto *OriginalFuncTy = F.getReturnType(); !OriginalFuncTy->isVoidTy())
13064+
OutStructElementsTypes.push_back(OriginalFuncTy);
13065+
13066+
for (auto *Arg : CandidateArgs)
13067+
OutStructElementsTypes.push_back(PtrToType[Arg]);
13068+
13069+
auto *ReturnStructType = StructType::create(
13070+
F.getContext(), OutStructElementsTypes, (F.getName() + "_out").str());
13071+
13072+
// Get the new Args.
13073+
SmallVector<Type *, 4> NewParamTypes;
13074+
for (auto &Arg : F.args())
13075+
if (!PtrToType.count(&Arg))
13076+
NewParamTypes.push_back(Arg.getType());
13077+
13078+
auto *NewFunctionType =
13079+
FunctionType::get(ReturnStructType, NewParamTypes, F.isVarArg());
13080+
auto *NewFunction =
13081+
Function::Create(NewFunctionType, F.getLinkage(), F.getAddressSpace(),
13082+
F.getName() + ".converted");
13083+
13084+
// Map old arguments to new ones, And also map the old arguments to struct
13085+
// elements.
13086+
ValueToValueMapTy VMap;
13087+
auto NewArgIt = NewFunction->arg_begin();
13088+
BasicBlock *EntryBlock =
13089+
BasicBlock::Create(NewFunction->getContext(), "entry", NewFunction);
13090+
13091+
IRBuilder<> EntryBuilder(EntryBlock);
13092+
for (auto &OldArg : F.args()) {
13093+
if (PtrToType.count(&OldArg)) {
13094+
dbgs() << "OldArg: " << OldArg
13095+
<< " ======> Type: " << *PtrToType[&OldArg] << "\n";
13096+
AllocaInst *Alloca = EntryBuilder.CreateAlloca(
13097+
PtrToType[&OldArg], nullptr, OldArg.getName() + "_");
13098+
VMap[&OldArg] = Alloca;
13099+
} else
13100+
VMap[&OldArg] = &(*NewArgIt++);
13101+
}
13102+
13103+
// Clone the old function into the new one.
13104+
SmallVector<ReturnInst *, 8> Returns;
13105+
CloneFunctionInto(NewFunction, &F, VMap,
13106+
CloneFunctionChangeType::LocalChangesOnly, Returns);
13107+
13108+
// Update the return values (make it struct).
13109+
for (ReturnInst *Ret : Returns) {
13110+
IRBuilder<> Builder(Ret);
13111+
SmallVector<Value *, 4> StructValues;
13112+
// Include original return type, if any
13113+
if (auto *OriginalFuncTy = F.getReturnType(); !OriginalFuncTy->isVoidTy())
13114+
StructValues.push_back(Ret->getReturnValue());
13115+
13116+
// Create a load instruction to fill the struct element.
13117+
for (auto *Arg : CandidateArgs) {
13118+
Value *OutVal = Builder.CreateLoad(PtrToType[Arg], VMap[Arg]);
13119+
StructValues.push_back(OutVal);
13120+
}
13121+
13122+
// Build the return struct incrementally.
13123+
Value *StructRetVal = UndefValue::get(ReturnStructType);
13124+
for (unsigned i = 0; i < StructValues.size(); ++i)
13125+
StructRetVal =
13126+
Builder.CreateInsertValue(StructRetVal, StructValues[i], i);
13127+
13128+
Builder.CreateRet(StructRetVal);
13129+
A.deleteAfterManifest(*Ret);
13130+
}
13131+
return ChangeStatus::CHANGED;
13132+
}
13133+
13134+
/// See AbstractAttribute::getAsStr(...).
13135+
const std::string getAsStr(Attributor *A) const override {
13136+
return "AAConvertOutArgumentFunction";
13137+
}
13138+
13139+
/// See AbstractAttribute::trackStatistics()
13140+
void trackStatistics() const override {}
13141+
};
13142+
13143+
struct AAConvertOutArgumentCallSite final : AAConvertOutArgument {
13144+
AAConvertOutArgumentCallSite(const IRPosition &IRP, Attributor &A)
13145+
: AAConvertOutArgument(IRP, A) {}
13146+
13147+
/// See AbstractAttribute::updateImpl(...).
13148+
ChangeStatus updateImpl(Attributor &A) override {
13149+
CallBase *CB = cast<CallBase>(getCtxI());
13150+
Function *F = CB->getCalledFunction();
13151+
if (!F)
13152+
return indicatePessimisticFixpoint();
13153+
13154+
// Get convert attribute.
13155+
auto *ConvertAA = A.getAAFor<AAConvertOutArgument>(
13156+
*this, IRPosition::function(*F), DepClassTy::REQUIRED);
13157+
13158+
// If function will be transformed, mark this call site for update
13159+
if (!ConvertAA || ConvertAA->isAssumedConvertible())
13160+
return ChangeStatus::CHANGED;
13161+
13162+
return ChangeStatus::UNCHANGED;
13163+
}
13164+
13165+
/// See AbstractAttribute::manifest(...).
13166+
ChangeStatus manifest(Attributor &A) override {
13167+
CallBase *CB = cast<CallBase>(getCtxI());
13168+
Function *F = CB->getCalledFunction();
13169+
if (!F)
13170+
return ChangeStatus::UNCHANGED;
13171+
13172+
IRBuilder<> Builder(CB);
13173+
// Create args for new call.
13174+
SmallVector<Value *, 4> NewArgs;
13175+
for (unsigned ArgIdx = 0; ArgIdx < CB->arg_size(); ++ArgIdx) {
13176+
Value *Arg = CB->getArgOperand(ArgIdx);
13177+
Argument *ParamArg = F->getArg(ArgIdx);
13178+
if (!isEligibleArgument(*ParamArg, A, *this))
13179+
NewArgs.push_back(Arg);
13180+
}
13181+
13182+
Module *M = F->getParent();
13183+
auto *NewF = M->getFunction((F->getName() + ".converted").str());
13184+
if (!NewF)
13185+
return ChangeStatus::UNCHANGED;
13186+
13187+
FunctionCallee NewCallee(NewF->getFunctionType(), NewF);
13188+
Instruction *NewCall =
13189+
CallInst::Create(NewCallee, NewArgs, CB->getName() + ".converted", CB);
13190+
IRPosition ReturnPos = IRPosition::callsite_returned(*CB);
13191+
A.changeAfterManifest(ReturnPos, *NewCall);
13192+
13193+
// Redirect all uses of the old call to the new call.
13194+
for (auto &Use : CB->uses())
13195+
Use.set(NewCall);
13196+
13197+
A.deleteAfterManifest(*CB);
13198+
return ChangeStatus::CHANGED;
13199+
}
13200+
13201+
/// See AbstractAttribute::getAsStr(...).
13202+
const std::string getAsStr(Attributor *A) const override {
13203+
return "AAConvertOutArgumentCallSite";
13204+
}
13205+
13206+
/// See AbstractAttribute::trackStatistics()
13207+
void trackStatistics() const override {}
13208+
};
13209+
} // namespace
13210+
1299013211
const char AANoUnwind::ID = 0;
1299113212
const char AANoSync::ID = 0;
1299213213
const char AANoFree::ID = 0;
@@ -13024,6 +13245,7 @@ const char AAAllocationInfo::ID = 0;
1302413245
const char AAIndirectCallInfo::ID = 0;
1302513246
const char AAGlobalValueInfo::ID = 0;
1302613247
const char AADenormalFPMath::ID = 0;
13248+
const char AAConvertOutArgument::ID = 0;
1302713249

1302813250
// Macro magic to create the static generator function for attributes that
1302913251
// follow the naming scheme.
@@ -13139,6 +13361,7 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation)
1313913361
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges)
1314013362
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo)
1314113363
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMustProgress)
13364+
CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAConvertOutArgument)
1314213365

1314313366
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull)
1314413367
CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias)

llvm/test/Transforms/Attributor/remove_out_args.ll

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,44 @@
33

44

55

6-
define internal i1 @foo(ptr %dst) {
6+
define internal i1 @foo(ptr %dst, i32 %a, i32 %b) {
77
entry:
8-
store i32 42, ptr %dst
9-
ret i1 true
8+
%x = xor i32 %a, 13
9+
%y = add i32 %b, 5
10+
%z = icmp sle i32 %x, %y
11+
br i1 %z, label %if, label %else
12+
13+
if:
14+
store i32 %x, ptr %dst, align 4
15+
br label %end
16+
17+
else:
18+
store i32 %y, ptr %dst, align 4
19+
br label %end
20+
21+
end:
22+
%t = mul i32 %x, %y
23+
%tt = xor i32 %x, %y
24+
%result = icmp eq i32 %t, %tt
25+
ret i1 %result
1026
}
1127

1228

13-
define i1 @fee(i32 %x, i32 %y) {
29+
define i1 @fee(i32 %x, i32 %y, i32 %z) {
30+
; CHECK-LABEL: define i1 @fee(
31+
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
32+
; CHECK-NEXT: [[PTR:%.*]] = alloca i32, align 4
33+
; CHECK-NEXT: [[A:%.*]] = call i1 @foo.converted(ptr noalias nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[PTR]], i32 [[Y]], i32 [[Z]]) #[[ATTR1:[0-9]+]]
34+
; CHECK-NEXT: [[B:%.*]] = load i32, ptr [[PTR]], align 4
35+
; CHECK-NEXT: [[C:%.*]] = icmp sle i32 [[B]], [[X]]
36+
; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[A]], [[C]]
37+
; CHECK-NEXT: ret i1 [[XOR]]
38+
;
1439
%ptr = alloca i32
15-
%a = call i1 @foo(ptr %ptr, i32 %y)
40+
%a = call i1 @foo(ptr %ptr, i32 %y, i32 %z)
1641
%b = load i32, ptr %ptr
1742
%c = icmp sle i32 %b, %x
1843
%xor = xor i1 %a, %c
1944
ret i1 %xor
2045
}
46+

0 commit comments

Comments
 (0)