From 6b545c2ffe586082156fe53a8a6c5c010f80cfc1 Mon Sep 17 00:00:00 2001 From: Mohamed Atef Date: Mon, 9 Dec 2024 22:24:29 +0200 Subject: [PATCH 1/2] [Attributor] Add pre-commit tests --- .../Transforms/Attributor/remove_out_args.ll | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 llvm/test/Transforms/Attributor/remove_out_args.ll diff --git a/llvm/test/Transforms/Attributor/remove_out_args.ll b/llvm/test/Transforms/Attributor/remove_out_args.ll new file mode 100644 index 0000000000000..40c39ea41ff67 --- /dev/null +++ b/llvm/test/Transforms/Attributor/remove_out_args.ll @@ -0,0 +1,20 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -passes=attributor < %s | FileCheck %s + + + +define internal i1 @foo(ptr %dst) { +entry: + store i32 42, ptr %dst + ret i1 true +} + + +define i1 @fee(i32 %x, i32 %y) { + %ptr = alloca i32 + %a = call i1 @foo(ptr %ptr, i32 %y) + %b = load i32, ptr %ptr + %c = icmp sle i32 %b, %x + %xor = xor i1 %a, %c + ret i1 %xor +} From 479165962e72bb5a9e5ac31154f092f2f08bbd4e Mon Sep 17 00:00:00 2001 From: Mohamed Atef Date: Thu, 20 Feb 2025 00:20:37 +0200 Subject: [PATCH 2/2] [Attributor] Convert out arguments into a struct return --- llvm/include/llvm/Transforms/IPO/Attributor.h | 33 +++ llvm/lib/Transforms/IPO/Attributor.cpp | 7 + .../Transforms/IPO/AttributorAttributes.cpp | 223 ++++++++++++++++++ .../Transforms/Attributor/remove_out_args.ll | 36 ++- 4 files changed, 294 insertions(+), 5 deletions(-) diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h index 8589314699749..19a319a57d326 100644 --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -6477,6 +6477,39 @@ struct AADenormalFPMath static const char ID; }; +/// An abstract attribute for converting out arguments into struct elements. +struct AAConvertOutArgument + : public StateWrapper { + using Base = StateWrapper; + + AAConvertOutArgument(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + + /// Create an abstract attribute view for the position \p IRP. + static AAConvertOutArgument &createForPosition(const IRPosition &IRP, + Attributor &A); + + /// See AbstractAttribute::getName() + const std::string getName() const override { return "AAConvertOutArgument"; } + + /// Return true if convertible is assumed. + bool isAssumedConvertible() const { return getAssumed(); } + + /// Return true if convertible is known. + bool isKnownConvertible() const { return getKnown(); } + + /// See AbstractAttribute::getIdAddr() + const char *getIdAddr() const override { return &ID; } + + /// This function should return true if the type of the \p AA is + /// AADenormalFPMath. + static bool classof(const AbstractAttribute *AA) { + return (AA->getIdAddr() == &ID); + } + + /// Unique ID (due to the unique address) + static const char ID; +}; + raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &); /// Run options, used by the pass manager. diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp index a93284926d684..373ce71afe358 100644 --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -3459,6 +3459,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { } } + bool markedAsAAConvertArgument = false; for (Argument &Arg : F.args()) { IRPosition ArgPos = IRPosition::argument(Arg); auto ArgNo = Arg.getArgNo(); @@ -3510,6 +3511,12 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every argument with pointer type might be privatizable (or // promotable) getOrCreateAAFor(ArgPos); + + // Every function with pointer argument type can have out arguments. + if (!markedAsAAConvertArgument) { + getOrCreateAAFor(FPos); + markedAsAAConvertArgument = true; + } } else if (AttributeFuncs::isNoFPClassCompatibleType(Arg.getType())) { getOrCreateAAFor(ArgPos); } diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 58b8f1f779f72..ccce5ac7c418a 100644 --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -68,6 +68,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include @@ -197,6 +198,7 @@ PIPE_OPERATOR(AAAllocationInfo) PIPE_OPERATOR(AAIndirectCallInfo) PIPE_OPERATOR(AAGlobalValueInfo) PIPE_OPERATOR(AADenormalFPMath) +PIPE_OPERATOR(AAConvertOutArgument) #undef PIPE_OPERATOR @@ -12987,6 +12989,225 @@ struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl { }; } // namespace +/// ----------- AAConvertOutArgument ---------- +namespace { +static bool isEligibleArgument(const Argument &Arg, Attributor &A, + const AbstractAttribute &AA) { + if (!Arg.getType()->isPointerTy()) + return false; + + const IRPosition &ArgPos = IRPosition::argument(Arg); + auto *AAMem = A.getAAFor(AA, ArgPos, DepClassTy::OPTIONAL); + auto *NoAlias = A.getAAFor(AA, ArgPos, DepClassTy::OPTIONAL); + + return AAMem && NoAlias && AAMem->isAssumedWriteOnly() && + NoAlias->isAssumedNoAlias() && !Arg.hasPointeeInMemoryValueAttr(); +} + +struct AAConvertOutArgumentFunction final : AAConvertOutArgument { + AAConvertOutArgumentFunction(const IRPosition &IRP, Attributor &A) + : AAConvertOutArgument(IRP, A) {} + + SmallVector ArgumentsStates; + + /// See AbstractAttribute::updateImpl(...). + void initialize(Attributor &A) override { + const Function *F = getAssociatedFunction(); + if (!F || F->isDeclaration()) + return; + + // Assume that all args are convertable at the begining. + ArgumentsStates.resize(F->arg_size(), true); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + const Function *F = getAssociatedFunction(); + if (!F || F->isDeclaration()) + return indicatePessimisticFixpoint(); + + auto NewStates = ArgumentsStates; + for (unsigned ArgIdx = 0; ArgIdx < F->arg_size(); ++ArgIdx) + if (!isEligibleArgument(*F->getArg(ArgIdx), A, *this)) + NewStates[ArgIdx] = false; + + bool Changed = NewStates == ArgumentsStates; + ArgumentsStates = NewStates; + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + Function &F = *getAssociatedFunction(); + DenseMap PtrToType; + SmallVector CandidateArgs; + + for (unsigned ArgIdx = 0; ArgIdx < F.arg_size(); ++ArgIdx) { + Argument *Arg = F.getArg(ArgIdx); + if (!isEligibleArgument(*Arg, A, *this)) + continue; + + CandidateArgs.push_back(Arg); + // AAPointerInfo on args + for (auto &Use : Arg->uses()) + if (auto *Store = dyn_cast(Use.getUser())) + PtrToType[Arg] = Store->getValueOperand()->getType(); + } + + // If there is no valid candidates then return false. + if (PtrToType.empty()) + return indicatePessimisticFixpoint(); + + // Create the new struct return type. + SmallVector OutStructElementsTypes; + if (auto *OriginalFuncTy = F.getReturnType(); !OriginalFuncTy->isVoidTy()) + OutStructElementsTypes.push_back(OriginalFuncTy); + + for (auto *Arg : CandidateArgs) + OutStructElementsTypes.push_back(PtrToType[Arg]); + + auto *ReturnStructType = StructType::create( + F.getContext(), OutStructElementsTypes, (F.getName() + "_out").str()); + + // Get the new Args. + SmallVector NewParamTypes; + for (auto &Arg : F.args()) + if (!PtrToType.count(&Arg)) + NewParamTypes.push_back(Arg.getType()); + + auto *NewFunctionType = + FunctionType::get(ReturnStructType, NewParamTypes, F.isVarArg()); + auto *NewFunction = + Function::Create(NewFunctionType, F.getLinkage(), F.getAddressSpace(), + F.getName() + ".converted"); + + // Map old arguments to new ones, And also map the old arguments to struct + // elements. + ValueToValueMapTy VMap; + auto NewArgIt = NewFunction->arg_begin(); + BasicBlock *EntryBlock = + BasicBlock::Create(NewFunction->getContext(), "entry", NewFunction); + + IRBuilder<> EntryBuilder(EntryBlock); + for (auto &OldArg : F.args()) { + if (PtrToType.count(&OldArg)) { + dbgs() << "OldArg: " << OldArg + << " ======> Type: " << *PtrToType[&OldArg] << "\n"; + AllocaInst *Alloca = EntryBuilder.CreateAlloca( + PtrToType[&OldArg], nullptr, OldArg.getName() + "_"); + VMap[&OldArg] = Alloca; + } else + VMap[&OldArg] = &(*NewArgIt++); + } + + // Clone the old function into the new one. + SmallVector Returns; + CloneFunctionInto(NewFunction, &F, VMap, + CloneFunctionChangeType::LocalChangesOnly, Returns); + + // Update the return values (make it struct). + for (ReturnInst *Ret : Returns) { + IRBuilder<> Builder(Ret); + SmallVector StructValues; + // Include original return type, if any + if (auto *OriginalFuncTy = F.getReturnType(); !OriginalFuncTy->isVoidTy()) + StructValues.push_back(Ret->getReturnValue()); + + // Create a load instruction to fill the struct element. + for (auto *Arg : CandidateArgs) { + Value *OutVal = Builder.CreateLoad(PtrToType[Arg], VMap[Arg]); + StructValues.push_back(OutVal); + } + + // Build the return struct incrementally. + Value *StructRetVal = UndefValue::get(ReturnStructType); + for (unsigned i = 0; i < StructValues.size(); ++i) + StructRetVal = + Builder.CreateInsertValue(StructRetVal, StructValues[i], i); + + Builder.CreateRet(StructRetVal); + A.deleteAfterManifest(*Ret); + } + return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::getAsStr(...). + const std::string getAsStr(Attributor *A) const override { + return "AAConvertOutArgumentFunction"; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; + +struct AAConvertOutArgumentCallSite final : AAConvertOutArgument { + AAConvertOutArgumentCallSite(const IRPosition &IRP, Attributor &A) + : AAConvertOutArgument(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + CallBase *CB = cast(getCtxI()); + Function *F = CB->getCalledFunction(); + if (!F) + return indicatePessimisticFixpoint(); + + // Get convert attribute. + auto *ConvertAA = A.getAAFor( + *this, IRPosition::function(*F), DepClassTy::REQUIRED); + + // If function will be transformed, mark this call site for update + if (!ConvertAA || ConvertAA->isAssumedConvertible()) + return ChangeStatus::CHANGED; + + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + CallBase *CB = cast(getCtxI()); + Function *F = CB->getCalledFunction(); + if (!F) + return ChangeStatus::UNCHANGED; + + IRBuilder<> Builder(CB); + // Create args for new call. + SmallVector NewArgs; + for (unsigned ArgIdx = 0; ArgIdx < CB->arg_size(); ++ArgIdx) { + Value *Arg = CB->getArgOperand(ArgIdx); + Argument *ParamArg = F->getArg(ArgIdx); + if (!isEligibleArgument(*ParamArg, A, *this)) + NewArgs.push_back(Arg); + } + + Module *M = F->getParent(); + auto *NewF = M->getFunction((F->getName() + ".converted").str()); + if (!NewF) + return ChangeStatus::UNCHANGED; + + FunctionCallee NewCallee(NewF->getFunctionType(), NewF); + Instruction *NewCall = + CallInst::Create(NewCallee, NewArgs, CB->getName() + ".converted", CB); + IRPosition ReturnPos = IRPosition::callsite_returned(*CB); + A.changeAfterManifest(ReturnPos, *NewCall); + + // Redirect all uses of the old call to the new call. + for (auto &Use : CB->uses()) + Use.set(NewCall); + + A.deleteAfterManifest(*CB); + return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::getAsStr(...). + const std::string getAsStr(Attributor *A) const override { + return "AAConvertOutArgumentCallSite"; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; +} // namespace + const char AANoUnwind::ID = 0; const char AANoSync::ID = 0; const char AANoFree::ID = 0; @@ -13024,6 +13245,7 @@ const char AAAllocationInfo::ID = 0; const char AAIndirectCallInfo::ID = 0; const char AAGlobalValueInfo::ID = 0; const char AADenormalFPMath::ID = 0; +const char AAConvertOutArgument::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. @@ -13139,6 +13361,7 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMustProgress) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAConvertOutArgument) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias) diff --git a/llvm/test/Transforms/Attributor/remove_out_args.ll b/llvm/test/Transforms/Attributor/remove_out_args.ll index 40c39ea41ff67..9f121f6adf27a 100644 --- a/llvm/test/Transforms/Attributor/remove_out_args.ll +++ b/llvm/test/Transforms/Attributor/remove_out_args.ll @@ -3,18 +3,44 @@ -define internal i1 @foo(ptr %dst) { +define internal i1 @foo(ptr %dst, i32 %a, i32 %b) { entry: - store i32 42, ptr %dst - ret i1 true + %x = xor i32 %a, 13 + %y = add i32 %b, 5 + %z = icmp sle i32 %x, %y + br i1 %z, label %if, label %else + +if: + store i32 %x, ptr %dst, align 4 + br label %end + +else: + store i32 %y, ptr %dst, align 4 + br label %end + +end: + %t = mul i32 %x, %y + %tt = xor i32 %x, %y + %result = icmp eq i32 %t, %tt + ret i1 %result } -define i1 @fee(i32 %x, i32 %y) { +define i1 @fee(i32 %x, i32 %y, i32 %z) { +; CHECK-LABEL: define i1 @fee( +; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) #[[ATTR0:[0-9]+]] { +; CHECK-NEXT: [[PTR:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[A:%.*]] = call i1 @foo.converted(ptr noalias nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[PTR]], i32 [[Y]], i32 [[Z]]) #[[ATTR1:[0-9]+]] +; CHECK-NEXT: [[B:%.*]] = load i32, ptr [[PTR]], align 4 +; CHECK-NEXT: [[C:%.*]] = icmp sle i32 [[B]], [[X]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[A]], [[C]] +; CHECK-NEXT: ret i1 [[XOR]] +; %ptr = alloca i32 - %a = call i1 @foo(ptr %ptr, i32 %y) + %a = call i1 @foo(ptr %ptr, i32 %y, i32 %z) %b = load i32, ptr %ptr %c = icmp sle i32 %b, %x %xor = xor i1 %a, %c ret i1 %xor } +