5252#include " llvm/IR/Argument.h"
5353#include " llvm/IR/BasicBlock.h"
5454#include " llvm/IR/Constant.h"
55+ #include " llvm/IR/ConstantRangeList.h"
5556#include " llvm/IR/Constants.h"
5657#include " llvm/IR/DataLayout.h"
5758#include " llvm/IR/DebugInfo.h"
@@ -164,6 +165,10 @@ static cl::opt<bool>
164165 OptimizeMemorySSA (" dse-optimize-memoryssa" , cl::init(true ), cl::Hidden,
165166 cl::desc(" Allow DSE to optimize memory accesses." ));
166167
168+ static cl::opt<bool > EnableInitializesImprovement (
169+ " enable-dse-initializes-attr-improvement" , cl::init(false ), cl::Hidden,
170+ cl::desc(" Enable the initializes attr improvement in DSE" ));
171+
167172// ===----------------------------------------------------------------------===//
168173// Helper functions
169174// ===----------------------------------------------------------------------===//
@@ -809,8 +814,10 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
809814// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
810815// defined by `MemDef`.
811816struct MemoryLocationWrapper {
812- MemoryLocationWrapper (MemoryLocation MemLoc, MemoryDef *MemDef)
813- : MemLoc(MemLoc), MemDef(MemDef) {
817+ MemoryLocationWrapper (MemoryLocation MemLoc, MemoryDef *MemDef,
818+ bool DefByInitializesAttr)
819+ : MemLoc(MemLoc), MemDef(MemDef),
820+ DefByInitializesAttr (DefByInitializesAttr) {
814821 assert (MemLoc.Ptr && " MemLoc should be not null" );
815822 UnderlyingObject = getUnderlyingObject (MemLoc.Ptr );
816823 DefInst = MemDef->getMemoryInst ();
@@ -820,20 +827,121 @@ struct MemoryLocationWrapper {
820827 const Value *UnderlyingObject;
821828 MemoryDef *MemDef;
822829 Instruction *DefInst;
830+ bool DefByInitializesAttr = false ;
823831};
824832
825833// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
826834// defined by this MemoryDef.
827835struct MemoryDefWrapper {
828- MemoryDefWrapper (MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
836+ MemoryDefWrapper (
837+ MemoryDef *MemDef,
838+ const SmallVectorImpl<std::pair<MemoryLocation, bool >> &MemLocations) {
829839 DefInst = MemDef->getMemoryInst ();
830- if (MemLoc.has_value ())
831- DefinedLocation = MemoryLocationWrapper (*MemLoc, MemDef);
840+ for (auto &[MemLoc, DefByInitializesAttr] : MemLocations)
841+ DefinedLocations.push_back (
842+ MemoryLocationWrapper (MemLoc, MemDef, DefByInitializesAttr));
832843 }
833844 Instruction *DefInst;
834- std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt ;
845+ SmallVector<MemoryLocationWrapper, 1 > DefinedLocations;
846+ };
847+
848+ bool HasInitializesAttr (Instruction *I) {
849+ CallBase *CB = dyn_cast<CallBase>(I);
850+ if (!CB)
851+ return false ;
852+
853+ for (size_t Idx = 0 ; Idx < CB->arg_size (); Idx++)
854+ if (CB->paramHasAttr (Idx, Attribute::Initializes))
855+ return true ;
856+ return false ;
857+ }
858+
859+ struct ArgumentInitInfo {
860+ size_t Idx = -1 ;
861+ ConstantRangeList Inits;
862+ bool HasDeadOnUnwindAttr = false ;
863+ bool FuncHasNoUnwindAttr = false ;
835864};
836865
866+ ConstantRangeList
867+ GetMergedInitAttr (const SmallVectorImpl<ArgumentInitInfo> &Args) {
868+ if (Args.empty ())
869+ return {};
870+
871+ // To address unwind, the function should have nounwind attribute or the
872+ // arguments have dead_on_unwind attribute. Otherwise, return empty.
873+ for (const auto &Arg : Args) {
874+ if (!Arg.FuncHasNoUnwindAttr && !Arg.HasDeadOnUnwindAttr )
875+ return {};
876+ if (Arg.Inits .empty ())
877+ return {};
878+ }
879+
880+ if (Args.size () == 1 )
881+ return Args[0 ].Inits ;
882+
883+ ConstantRangeList MergedIntervals = Args[0 ].Inits ;
884+ for (size_t i = 1 ; i < Args.size (); i++)
885+ MergedIntervals = MergedIntervals.intersectWith (Args[i].Inits );
886+
887+ return MergedIntervals;
888+ }
889+
890+ // Return the locations wrote by the initializes attribute.
891+ // Note that this function considers:
892+ // 1. Unwind edge: apply "initializes" attribute only if the callee has
893+ // "nounwind" attribute or the argument has "dead_on_unwind" attribute.
894+ // 2. Argument alias: for aliasing arguments, the "initializes" attribute is
895+ // the merged range list of their "initializes" attributes.
896+ SmallVector<MemoryLocation, 1 >
897+ GetInitializesArgMemLoc (const Instruction *I, BatchAAResults &BatchAA) {
898+ const CallBase *CB = dyn_cast<CallBase>(I);
899+ if (!CB)
900+ return {};
901+
902+ // Collect aliasing arguments and their initializes ranges.
903+ bool HasNoUnwindAttr = CB->hasFnAttr (Attribute::NoUnwind);
904+ SmallMapVector<Value *, SmallVector<ArgumentInitInfo, 2 >, 2 > Arguments;
905+ for (size_t Idx = 0 ; Idx < CB->arg_size (); Idx++) {
906+ ConstantRangeList Inits;
907+ if (CB->paramHasAttr (Idx, Attribute::Initializes))
908+ Inits = CB->getParamAttr (Idx, Attribute::Initializes)
909+ .getValueAsConstantRangeList ();
910+
911+ bool HasDeadOnUnwindAttr = CB->paramHasAttr (Idx, Attribute::DeadOnUnwind);
912+ ArgumentInitInfo InitInfo{Idx, Inits, HasDeadOnUnwindAttr, HasNoUnwindAttr};
913+ Value *CurArg = CB->getArgOperand (Idx);
914+ bool FoundAliasing = false ;
915+ for (auto &[Arg, AliasList] : Arguments) {
916+ if (BatchAA.isMustAlias (Arg, CurArg)) {
917+ FoundAliasing = true ;
918+ AliasList.push_back (InitInfo);
919+ }
920+ }
921+ if (!FoundAliasing)
922+ Arguments[CurArg] = {InitInfo};
923+ }
924+
925+ SmallVector<MemoryLocation, 1 > Locations;
926+ for (const auto &[_, Args] : Arguments) {
927+ auto MergedInitAttr = GetMergedInitAttr (Args);
928+ if (MergedInitAttr.empty ())
929+ continue ;
930+
931+ for (const auto &Arg : Args) {
932+ for (const auto &Range : MergedInitAttr) {
933+ int64_t Start = Range.getLower ().getSExtValue ();
934+ int64_t End = Range.getUpper ().getSExtValue ();
935+ if (Start == 0 )
936+ Locations.push_back (MemoryLocation (CB->getArgOperand (Arg.Idx ),
937+ LocationSize::precise (End - Start),
938+ CB->getAAMetadata ()));
939+ }
940+ }
941+ }
942+ return Locations;
943+ }
944+
837945struct DSEState {
838946 Function &F;
839947 AliasAnalysis &AA;
@@ -911,7 +1019,8 @@ struct DSEState {
9111019
9121020 auto *MD = dyn_cast_or_null<MemoryDef>(MA);
9131021 if (MD && MemDefs.size () < MemorySSADefsPerBlockLimit &&
914- (getLocForWrite (&I) || isMemTerminatorInst (&I)))
1022+ (getLocForWrite (&I) || isMemTerminatorInst (&I) ||
1023+ HasInitializesAttr (&I)))
9151024 MemDefs.push_back (MD);
9161025 }
9171026 }
@@ -1147,13 +1256,26 @@ struct DSEState {
11471256 return MemoryLocation::getOrNone (I);
11481257 }
11491258
1150- std::optional<MemoryLocation> getLocForInst (Instruction *I) {
1259+ // Returns a list of <MemoryLocation, bool> pairs wrote by I.
1260+ // The bool means whether the write is from Initializes attr.
1261+ SmallVector<std::pair<MemoryLocation, bool >, 1 >
1262+ getLocForInst (Instruction *I, bool ConsiderInitializesAttr) {
1263+ SmallVector<std::pair<MemoryLocation, bool >, 1 > Locations;
11511264 if (isMemTerminatorInst (I)) {
1152- if (auto Loc = getLocForTerminator (I)) {
1153- return Loc->first ;
1265+ if (auto Loc = getLocForTerminator (I))
1266+ Locations.push_back (std::make_pair (Loc->first , false ));
1267+ return Locations;
1268+ }
1269+
1270+ if (auto Loc = getLocForWrite (I))
1271+ Locations.push_back (std::make_pair (*Loc, false ));
1272+
1273+ if (ConsiderInitializesAttr) {
1274+ for (auto &MemLoc : GetInitializesArgMemLoc (I, BatchAA)) {
1275+ Locations.push_back (std::make_pair (MemLoc, true ));
11541276 }
11551277 }
1156- return getLocForWrite (I) ;
1278+ return Locations ;
11571279 }
11581280
11591281 // / Assuming this instruction has a dead analyzable write, can we delete
@@ -1365,7 +1487,8 @@ struct DSEState {
13651487 getDomMemoryDef (MemoryDef *KillingDef, MemoryAccess *StartAccess,
13661488 const MemoryLocation &KillingLoc, const Value *KillingUndObj,
13671489 unsigned &ScanLimit, unsigned &WalkerStepLimit,
1368- bool IsMemTerm, unsigned &PartialLimit) {
1490+ bool IsMemTerm, unsigned &PartialLimit,
1491+ bool IsInitializesAttrMemLoc) {
13691492 if (ScanLimit == 0 || WalkerStepLimit == 0 ) {
13701493 LLVM_DEBUG (dbgs () << " \n ... hit scan limit\n " );
13711494 return std::nullopt ;
@@ -1602,7 +1725,17 @@ struct DSEState {
16021725
16031726 // Uses which may read the original MemoryDef mean we cannot eliminate the
16041727 // original MD. Stop walk.
1605- if (isReadClobber (MaybeDeadLoc, UseInst)) {
1728+ // If KillingDef is a CallInst with "initializes" attribute, the reads in
1729+ // Callee would be dominated by initializations, so this should be safe.
1730+ bool IsKillingDefFromInitAttr = false ;
1731+ if (IsInitializesAttrMemLoc) {
1732+ if (KillingI == UseInst &&
1733+ KillingUndObj == getUnderlyingObject (MaybeDeadLoc.Ptr )) {
1734+ IsKillingDefFromInitAttr = true ;
1735+ }
1736+ }
1737+
1738+ if (isReadClobber (MaybeDeadLoc, UseInst) && !IsKillingDefFromInitAttr) {
16061739 LLVM_DEBUG (dbgs () << " ... found read clobber\n " );
16071740 return std::nullopt ;
16081741 }
@@ -2207,7 +2340,8 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
22072340 std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef (
22082341 KillingLocWrapper.MemDef , Current, KillingLocWrapper.MemLoc ,
22092342 KillingLocWrapper.UnderlyingObject , ScanLimit, WalkerStepLimit,
2210- isMemTerminatorInst (KillingLocWrapper.DefInst ), PartialLimit);
2343+ isMemTerminatorInst (KillingLocWrapper.DefInst ), PartialLimit,
2344+ KillingLocWrapper.DefByInitializesAttr );
22112345
22122346 if (!MaybeDeadAccess) {
22132347 LLVM_DEBUG (dbgs () << " finished walk\n " );
@@ -2232,8 +2366,11 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
22322366 }
22332367 MemoryDefWrapper DeadDefWrapper (
22342368 cast<MemoryDef>(DeadAccess),
2235- getLocForInst (cast<MemoryDef>(DeadAccess)->getMemoryInst ()));
2236- MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation ;
2369+ getLocForInst (cast<MemoryDef>(DeadAccess)->getMemoryInst (),
2370+ /* ConsiderInitializesAttr=*/ false ));
2371+ assert (DeadDefWrapper.DefinedLocations .size () == 1 );
2372+ MemoryLocationWrapper &DeadLocWrapper =
2373+ DeadDefWrapper.DefinedLocations .front ();
22372374 LLVM_DEBUG (dbgs () << " (" << *DeadLocWrapper.DefInst << " )\n " );
22382375 ToCheck.insert (DeadLocWrapper.MemDef ->getDefiningAccess ());
22392376 NumGetDomMemoryDefPassed++;
@@ -2311,37 +2448,41 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
23112448}
23122449
23132450bool DSEState::eliminateDeadDefs (const MemoryDefWrapper &KillingDefWrapper) {
2314- if (! KillingDefWrapper.DefinedLocation . has_value ()) {
2451+ if (KillingDefWrapper.DefinedLocations . empty ()) {
23152452 LLVM_DEBUG (dbgs () << " Failed to find analyzable write location for "
23162453 << *KillingDefWrapper.DefInst << " \n " );
23172454 return false ;
23182455 }
23192456
2320- auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation ;
2321- LLVM_DEBUG (dbgs () << " Trying to eliminate MemoryDefs killed by "
2322- << *KillingLocWrapper.MemDef << " ("
2323- << *KillingLocWrapper.DefInst << " )\n " );
2324- auto [Changed, DeletedKillingLoc] = eliminateDeadDefs (KillingLocWrapper);
2325-
2326- // Check if the store is a no-op.
2327- if (!DeletedKillingLoc && storeIsNoop (KillingLocWrapper.MemDef ,
2328- KillingLocWrapper.UnderlyingObject )) {
2329- LLVM_DEBUG (dbgs () << " DSE: Remove No-Op Store:\n DEAD: "
2330- << *KillingLocWrapper.DefInst << ' \n ' );
2331- deleteDeadInstruction (KillingLocWrapper.DefInst );
2332- NumRedundantStores++;
2333- return true ;
2334- }
2335- // Can we form a calloc from a memset/malloc pair?
2336- if (!DeletedKillingLoc &&
2337- tryFoldIntoCalloc (KillingLocWrapper.MemDef ,
2338- KillingLocWrapper.UnderlyingObject )) {
2339- LLVM_DEBUG (dbgs () << " DSE: Remove memset after forming calloc:\n "
2340- << " DEAD: " << *KillingLocWrapper.DefInst << ' \n ' );
2341- deleteDeadInstruction (KillingLocWrapper.DefInst );
2342- return true ;
2457+ bool MadeChange = false ;
2458+ for (auto &KillingLocWrapper : KillingDefWrapper.DefinedLocations ) {
2459+ LLVM_DEBUG (dbgs () << " Trying to eliminate MemoryDefs killed by "
2460+ << *KillingLocWrapper.MemDef << " ("
2461+ << *KillingLocWrapper.DefInst << " )\n " );
2462+ auto [Changed, DeletedKillingLoc] = eliminateDeadDefs (KillingLocWrapper);
2463+
2464+ // Check if the store is a no-op.
2465+ if (!DeletedKillingLoc && storeIsNoop (KillingLocWrapper.MemDef ,
2466+ KillingLocWrapper.UnderlyingObject )) {
2467+ LLVM_DEBUG (dbgs () << " DSE: Remove No-Op Store:\n DEAD: "
2468+ << *KillingLocWrapper.DefInst << ' \n ' );
2469+ deleteDeadInstruction (KillingLocWrapper.DefInst );
2470+ NumRedundantStores++;
2471+ MadeChange = true ;
2472+ continue ;
2473+ }
2474+ // Can we form a calloc from a memset/malloc pair?
2475+ if (!DeletedKillingLoc &&
2476+ tryFoldIntoCalloc (KillingLocWrapper.MemDef ,
2477+ KillingLocWrapper.UnderlyingObject )) {
2478+ LLVM_DEBUG (dbgs () << " DSE: Remove memset after forming calloc:\n "
2479+ << " DEAD: " << *KillingLocWrapper.DefInst << ' \n ' );
2480+ deleteDeadInstruction (KillingLocWrapper.DefInst );
2481+ MadeChange = true ;
2482+ continue ;
2483+ }
23432484 }
2344- return Changed ;
2485+ return MadeChange ;
23452486}
23462487
23472488static bool eliminateDeadStores (Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
@@ -2357,7 +2498,8 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
23572498 continue ;
23582499
23592500 MemoryDefWrapper KillingDefWrapper (
2360- KillingDef, State.getLocForInst (KillingDef->getMemoryInst ()));
2501+ KillingDef, State.getLocForInst (KillingDef->getMemoryInst (),
2502+ EnableInitializesImprovement));
23612503 MadeChange |= State.eliminateDeadDefs (KillingDefWrapper);
23622504 }
23632505
0 commit comments