
Commit d0c90ca

[VPlan] Sink predicated stores with complementary masks.
Extend the logic for hoisting predicated loads (llvm#168373) to sink predicated stores with complementary masks in a similar fashion. The patch refactors some of the existing legality checks so they can be shared between hoisting and sinking, and adds a new sinking transform on top. For sinking stores, the legality checks must also consider aliasing stores, not only aliasing loads.
1 parent 371302a · commit d0c90ca
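For intuition, the pattern this transform targets looks roughly like the loop below (a hypothetical illustration, not a test case from this patch): both branches store to the same address under complementary conditions, so the vectorizer's predicated stores can be folded into a select feeding a single unconditional store.

```cpp
// Hypothetical example: complementary predicated stores to a[i].
void f(int *a, const int *b, int n) {
  for (int i = 0; i < n; ++i) {
    if (b[i] > 0)
      a[i] = 1; // executes under mask P
    else
      a[i] = 2; // executes under mask NOT(P)
  }
}

// After the transform, the loop body is conceptually:
//   a[i] = (b[i] > 0) ? 1 : 2;   // select + unconditional store
```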

File tree

4 files changed, +316 -281 lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -8318,6 +8318,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
            std::unique_ptr<VPlan>(VPlan0->duplicate()), SubRange, &LVer)) {
       // Now optimize the initial VPlan.
       VPlanTransforms::hoistPredicatedLoads(*Plan, *PSE.getSE(), OrigLoop);
+      VPlanTransforms::sinkPredicatedStores(*Plan, *PSE.getSE(), OrigLoop);
       VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths,
                                *Plan, CM.getMinimalBitwidths());
       VPlanTransforms::runPass(VPlanTransforms::optimize, *Plan);
```

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 230 additions & 85 deletions
```diff
@@ -3976,42 +3976,67 @@ void VPlanTransforms::hoistInvariantLoads(VPlan &Plan) {
   }
 }
 
-// Returns the intersection of metadata from a group of loads.
-static VPIRMetadata getCommonLoadMetadata(ArrayRef<VPReplicateRecipe *> Loads) {
-  VPIRMetadata CommonMetadata = *Loads.front();
-  for (VPReplicateRecipe *Load : drop_begin(Loads))
-    CommonMetadata.intersect(*Load);
+// Collect common metadata from a group of replicate recipes by intersecting
+// metadata from all recipes in the group.
+static VPIRMetadata getCommonMetadata(ArrayRef<VPReplicateRecipe *> Recipes) {
+  VPIRMetadata CommonMetadata = *Recipes.front();
+  for (VPReplicateRecipe *Recipe : drop_begin(Recipes))
+    CommonMetadata.intersect(*Recipe);
   return CommonMetadata;
 }
 
-// Check if a load can be hoisted by verifying it doesn't alias with any stores
-// in blocks between FirstBB and LastBB using scoped noalias metadata.
-static bool canHoistLoadWithNoAliasCheck(VPReplicateRecipe *Load,
-                                         VPBasicBlock *FirstBB,
-                                         VPBasicBlock *LastBB) {
-  // Get the load's memory location and check if it aliases with any stores
-  // using scoped noalias metadata.
-  auto LoadLoc = vputils::getMemoryLocation(*Load);
-  if (!LoadLoc || !LoadLoc->AATags.Scope)
+// Helper to check if we can prove no aliasing using scoped noalias metadata.
+static bool canProveNoAlias(const AAMDNodes &AA1, const AAMDNodes &AA2) {
+  return AA1.Scope && AA2.NoAlias &&
+         !ScopedNoAliasAAResult::mayAliasInScopes(AA1.Scope, AA2.NoAlias);
+}
+
+// Check if a memory operation doesn't alias with memory operations in blocks
+// between FirstBB and LastBB using scoped noalias metadata.
+// For load hoisting, we only check writes in one direction.
+// For store sinking, we check both reads and writes bidirectionally.
+static bool canHoistOrSinkWithNoAliasCheck(
+    const MemoryLocation &MemLoc, VPBasicBlock *FirstBB, VPBasicBlock *LastBB,
+    bool CheckReads,
+    const SmallPtrSetImpl<VPRecipeBase *> *ExcludeRecipes = nullptr) {
+  if (!MemLoc.AATags.Scope)
     return false;
 
-  const AAMDNodes &LoadAA = LoadLoc->AATags;
+  const AAMDNodes &MemAA = MemLoc.AATags;
+
   for (VPBlockBase *Block = FirstBB; Block;
        Block = Block->getSingleSuccessor()) {
-    // This function assumes a simple linear chain of blocks. If there are
-    // multiple successors, we would need more complex analysis.
     assert(Block->getNumSuccessors() <= 1 &&
            "Expected at most one successor in block chain");
     auto *VPBB = cast<VPBasicBlock>(Block);
     for (VPRecipeBase &R : *VPBB) {
-      if (R.mayWriteToMemory()) {
-        auto Loc = vputils::getMemoryLocation(R);
-        // Bail out if we can't get the location or if the scoped noalias
-        // metadata indicates potential aliasing.
-        if (!Loc || ScopedNoAliasAAResult::mayAliasInScopes(
-                        LoadAA.Scope, Loc->AATags.NoAlias))
-          return false;
+      if (ExcludeRecipes && ExcludeRecipes->contains(&R))
+        continue;
+
+      // Skip recipes that don't need checking.
+      if (!R.mayWriteToMemory() && !(CheckReads && R.mayReadFromMemory()))
+        continue;
+
+      auto Loc = vputils::getMemoryLocation(R);
+      if (!Loc)
+        // Conservatively assume aliasing for memory operations without
+        // location. We already filtered by
+        // mayWriteToMemory()/mayReadFromMemory() above.
+        return false;
+
+      // Check for aliasing using scoped noalias metadata.
+      // For store sinking with CheckReads, we can prove no aliasing
+      // bidirectionally (either direction suffices).
+      if (CheckReads) {
+        if (canProveNoAlias(Loc->AATags, MemAA) ||
+            canProveNoAlias(MemAA, Loc->AATags))
+          continue;
       }
+
+      // Check if the memory operations may alias in the standard direction.
+      if (ScopedNoAliasAAResult::mayAliasInScopes(MemAA.Scope,
+                                                  Loc->AATags.NoAlias))
+        return false;
     }
 
     if (Block == LastBB)
```
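The legality checks above rely on scoped noalias metadata (added, for example, by loop versioning). As a rough mental model, one access provably cannot alias another when every alias scope it carries is excluded by the other access's noalias list. The sketch below illustrates that rule with plain string sets; it is an assumption-laden simplification, since the real ScopedNoAliasAAResult::mayAliasInScopes also groups scopes into domains, which is ignored here.

```cpp
#include <set>
#include <string>

using ScopeSet = std::set<std::string>;

// Simplified model of the scoped-noalias query: conservatively answer
// "may alias" unless every scope of the first access is excluded by the
// second access's noalias list. (Illustrative only; the scope domains of
// the real analysis are ignored.)
bool mayAliasInScopes(const ScopeSet &Scopes, const ScopeSet &NoAlias) {
  if (Scopes.empty() || NoAlias.empty())
    return true; // No metadata to reason with; assume aliasing.
  for (const std::string &S : Scopes)
    if (!NoAlias.count(S))
      return true; // Some scope is not excluded.
  return false; // Every scope is excluded: provably no alias.
}
```

canProveNoAlias asks this question with the operands in fixed roles; for store sinking the CheckReads path tries both directions, since it suffices that either access carries metadata excluding the other. The second hunk replaces the load-only implementation with a shared collection helper and builds both transforms on top of it: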
```diff
@@ -4020,103 +4045,223 @@ static bool canHoistLoadWithNoAliasCheck(VPReplicateRecipe *Load,
   return true;
 }
 
-void VPlanTransforms::hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
-                                           const Loop *L) {
+template <unsigned Opcode>
+static SmallVector<SmallVector<VPReplicateRecipe *, 4>>
+collectComplementaryPredicatedMemOps(VPlan &Plan, ScalarEvolution &SE,
+                                     const Loop *L) {
+  static_assert(Opcode == Instruction::Load || Opcode == Instruction::Store,
+                "Only Load and Store opcodes supported");
+  constexpr bool IsLoad = (Opcode == Instruction::Load);
   VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
   VPTypeAnalysis TypeInfo(Plan);
-  VPDominatorTree VPDT(Plan);
 
-  // Group predicated loads by their address SCEV.
-  MapVector<const SCEV *, SmallVector<VPReplicateRecipe *>> LoadsByAddress;
+  // Group predicated operations by their address SCEV.
+  MapVector<const SCEV *, SmallVector<VPReplicateRecipe *>> RecipesByAddress;
   for (VPBlockBase *Block : vp_depth_first_shallow(LoopRegion->getEntry())) {
     auto *VPBB = cast<VPBasicBlock>(Block);
     for (VPRecipeBase &R : *VPBB) {
       auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
-      if (!RepR || RepR->getOpcode() != Instruction::Load ||
-          !RepR->isPredicated())
+      if (!RepR || RepR->getOpcode() != Opcode || !RepR->isPredicated())
         continue;
 
-      VPValue *Addr = RepR->getOperand(0);
+      // For loads, operand 0 is address; for stores, operand 1 is address.
+      VPValue *Addr = RepR->getOperand(IsLoad ? 0 : 1);
       const SCEV *AddrSCEV = vputils::getSCEVExprForVPValue(Addr, SE, L);
       if (!isa<SCEVCouldNotCompute>(AddrSCEV))
-        LoadsByAddress[AddrSCEV].push_back(RepR);
+        RecipesByAddress[AddrSCEV].push_back(RepR);
     }
   }
 
-  // For each address, collect loads with complementary masks, sort by
-  // dominance, and use the earliest load.
-  for (auto &[Addr, Loads] : LoadsByAddress) {
-    if (Loads.size() < 2)
+  // For each address, collect operations with the same or complementary masks.
+  SmallVector<SmallVector<VPReplicateRecipe *, 4>> AllGroups;
+  for (auto &[Addr, Recipes] : RecipesByAddress) {
+    if (Recipes.size() < 2)
       continue;
 
-    // Collect groups of loads with complementary masks.
-    SmallVector<SmallVector<VPReplicateRecipe *, 4>> LoadGroups;
-    for (VPReplicateRecipe *&LoadI : Loads) {
-      if (!LoadI)
+    // Collect groups with the same or complementary masks.
+    for (VPReplicateRecipe *&RecipeI : Recipes) {
+      if (!RecipeI)
        continue;
 
-      VPValue *MaskI = LoadI->getMask();
-      Type *TypeI = TypeInfo.inferScalarType(LoadI);
+      VPValue *MaskI = RecipeI->getMask();
+      Type *TypeI =
+          TypeInfo.inferScalarType(IsLoad ? RecipeI : RecipeI->getOperand(0));
       SmallVector<VPReplicateRecipe *, 4> Group;
-      Group.push_back(LoadI);
-      LoadI = nullptr;
+      Group.push_back(RecipeI);
+      RecipeI = nullptr;
 
-      // Find all loads with the same type.
-      for (VPReplicateRecipe *&LoadJ : Loads) {
-        if (!LoadJ)
+      // Find all operations with the same or complementary masks.
+      bool HasComplementaryMask = false;
+      for (VPReplicateRecipe *&RecipeJ : Recipes) {
+        if (!RecipeJ)
          continue;
 
-        Type *TypeJ = TypeInfo.inferScalarType(LoadJ);
+        VPValue *MaskJ = RecipeJ->getMask();
+        Type *TypeJ =
+            TypeInfo.inferScalarType(IsLoad ? RecipeJ : RecipeJ->getOperand(0));
         if (TypeI == TypeJ) {
-          Group.push_back(LoadJ);
-          LoadJ = nullptr;
+          // Check if any operation in the group has a complementary mask with
+          // another, that is M1 == NOT(M2) or M2 == NOT(M1).
+          HasComplementaryMask |= match(MaskI, m_Not(m_Specific(MaskJ))) ||
+                                  match(MaskJ, m_Not(m_Specific(MaskI)));
+          Group.push_back(RecipeJ);
+          RecipeJ = nullptr;
         }
       }
 
-      // Check if any load in the group has a complementary mask with another,
-      // that is M1 == NOT(M2) or M2 == NOT(M1).
-      bool HasComplementaryMask =
-          any_of(drop_begin(Group), [MaskI](VPReplicateRecipe *Load) {
-            VPValue *MaskJ = Load->getMask();
-            return match(MaskI, m_Not(m_Specific(MaskJ))) ||
-                   match(MaskJ, m_Not(m_Specific(MaskI)));
-          });
+      if (HasComplementaryMask) {
+        assert(Group.size() >= 2 && "must have at least 2 entries");
+        AllGroups.push_back(std::move(Group));
+      }
+    }
+  }
+
+  return AllGroups;
+}
+
+void VPlanTransforms::hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
+                                           const Loop *L) {
+  auto Groups =
+      collectComplementaryPredicatedMemOps<Instruction::Load>(Plan, SE, L);
+  if (Groups.empty())
+    return;
+
+  VPDominatorTree VPDT(Plan);
+
+  // Process each group of loads.
+  for (auto &Group : Groups) {
+    // Sort loads by dominance order, with earliest (most dominating) first.
+    sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
+      return VPDT.properlyDominates(A, B);
+    });
+
+    // Try to use the earliest (most dominating) load to replace all others.
+    VPReplicateRecipe *EarliestLoad = Group[0];
+    VPBasicBlock *FirstBB = EarliestLoad->getParent();
+    VPBasicBlock *LastBB = Group.back()->getParent();
+
+    // Check that the load doesn't alias with stores between first and last.
+    auto LoadLoc = vputils::getMemoryLocation(*EarliestLoad);
+    if (!LoadLoc ||
+        !canHoistOrSinkWithNoAliasCheck(*LoadLoc, FirstBB, LastBB,
+                                        /*CheckReads=*/false))
+      continue;
+
+    // Collect common metadata from all loads in the group.
+    VPIRMetadata CommonMetadata = getCommonMetadata(Group);
+
+    // Create an unpredicated version of the earliest load with common
+    // metadata.
+    auto *UnpredicatedLoad = new VPReplicateRecipe(
+        EarliestLoad->getUnderlyingInstr(), {EarliestLoad->getOperand(0)},
+        /*IsSingleScalar=*/false, /*Mask=*/nullptr, *EarliestLoad,
+        CommonMetadata);
 
-      if (HasComplementaryMask)
-        LoadGroups.push_back(std::move(Group));
+    UnpredicatedLoad->insertBefore(EarliestLoad);
+
+    // Replace all loads in the group with the unpredicated load.
+    for (VPReplicateRecipe *Load : Group) {
+      Load->replaceAllUsesWith(UnpredicatedLoad);
+      Load->eraseFromParent();
     }
+  }
+}
 
-    // For each group, check memory dependencies and hoist the earliest load.
-    for (auto &Group : LoadGroups) {
-      // Sort loads by dominance order, with earliest (most dominating) first.
-      sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
-        return VPDT.properlyDominates(A, B);
-      });
+static bool canSinkStoreWithNoAliasCheck(
+    VPReplicateRecipe *Store, ArrayRef<VPReplicateRecipe *> StoresToSink,
+    const SmallPtrSetImpl<VPRecipeBase *> *AlreadySunkStores = nullptr) {
+  auto StoreLoc = vputils::getMemoryLocation(*Store);
+  if (!StoreLoc)
+    return false;
 
-      VPReplicateRecipe *EarliestLoad = Group.front();
-      VPBasicBlock *FirstBB = EarliestLoad->getParent();
-      VPBasicBlock *LastBB = Group.back()->getParent();
+  SmallPtrSet<VPRecipeBase *, 4> StoresToSinkSet(StoresToSink.begin(),
+                                                 StoresToSink.end());
+  if (AlreadySunkStores)
+    StoresToSinkSet.insert(AlreadySunkStores->begin(),
+                           AlreadySunkStores->end());
 
-      // Check that the load doesn't alias with stores between first and last.
-      if (!canHoistLoadWithNoAliasCheck(EarliestLoad, FirstBB, LastBB))
+  VPBasicBlock *FirstBB = StoresToSink.front()->getParent();
+  VPBasicBlock *LastBB = StoresToSink.back()->getParent();
+
+  if (StoreLoc->AATags.Scope)
+    return canHoistOrSinkWithNoAliasCheck(*StoreLoc, FirstBB, LastBB,
+                                          /*CheckReads=*/true,
+                                          &StoresToSinkSet);
+
+  // Without alias scope metadata, we conservatively require no memory
+  // operations between the stores being sunk.
+  for (VPBlockBase *Block = FirstBB; Block;
+       Block = Block->getSingleSuccessor()) {
+    auto *VPBB = cast<VPBasicBlock>(Block);
+    for (VPRecipeBase &R : *VPBB) {
+      if (StoresToSinkSet.contains(&R))
         continue;
 
-      // Collect common metadata from all loads in the group.
-      VPIRMetadata CommonMetadata = getCommonLoadMetadata(Group);
+      if (R.mayReadFromMemory() || R.mayWriteToMemory())
+        return false;
+    }
 
-      // Create an unpredicated version of the earliest load with common
-      // metadata.
-      auto *UnpredicatedLoad = new VPReplicateRecipe(
-          EarliestLoad->getUnderlyingInstr(), {EarliestLoad->getOperand(0)},
-          /*IsSingleScalar=*/false, /*Mask=*/nullptr, *EarliestLoad, CommonMetadata);
+    if (Block == LastBB)
+      break;
+  }
 
-      UnpredicatedLoad->insertBefore(EarliestLoad);
+  return true;
+}
 
-      // Replace all loads in the group with the unpredicated load.
-      for (VPReplicateRecipe *Load : Group) {
-        Load->replaceAllUsesWith(UnpredicatedLoad);
-        Load->eraseFromParent();
-      }
+void VPlanTransforms::sinkPredicatedStores(VPlan &Plan, ScalarEvolution &SE,
+                                           const Loop *L) {
+  auto Groups =
+      collectComplementaryPredicatedMemOps<Instruction::Store>(Plan, SE, L);
+
+  if (Groups.empty())
+    return;
+
+  VPDominatorTree VPDT(Plan);
+
+  // Track stores from all groups that have been successfully sunk to exclude
+  // them from alias checks for subsequent groups.
+  SmallPtrSet<VPRecipeBase *, 16> AlreadySunkStores;
+
+  for (auto &Group : Groups) {
+    sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
+      return VPDT.properlyDominates(A, B);
+    });
+
+    if (!canSinkStoreWithNoAliasCheck(Group[0], Group, &AlreadySunkStores))
+      continue;
+
+    // Use the last (most dominated) store's location for the unconditional
+    // store.
+    VPReplicateRecipe *LastStore = Group.back();
+    VPBasicBlock *InsertBB = LastStore->getParent();
+
+    // Collect common alias metadata from all stores in the group.
+    VPIRMetadata CommonMetadata = getCommonMetadata(Group);
+
+    // Build select chain for stored values.
+    VPValue *SelectedValue = Group[0]->getOperand(0);
+    VPBuilder Builder(InsertBB, LastStore->getIterator());
+
+    for (unsigned I = 1; I < Group.size(); ++I) {
+      VPValue *Mask = Group[I]->getMask();
+      VPValue *Value = Group[I]->getOperand(0);
+      SelectedValue = Builder.createSelect(Mask, Value, SelectedValue,
+                                           Group[I]->getDebugLoc());
+    }
+
+    // Create unconditional store with selected value and common metadata.
+    VPValue *AddrVPValue = Group[0]->getOperand(1);
+    SmallVector<VPValue *> Operands = {SelectedValue, AddrVPValue};
+    auto *SI = cast<StoreInst>(Group[0]->getUnderlyingInstr());
+    auto *UnpredicatedStore =
+        new VPReplicateRecipe(SI, Operands, /*IsSingleScalar=*/false,
+                              /*Mask=*/nullptr, *LastStore, CommonMetadata);
+    UnpredicatedStore->insertBefore(*InsertBB, LastStore->getIterator());
+
+    // Track and remove all predicated stores from the group.
+    for (VPReplicateRecipe *Store : Group) {
+      AlreadySunkStores.insert(Store);
+      Store->eraseFromParent();
     }
   }
 }
```
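The heart of sinkPredicatedStores is the select chain: after sorting by dominance, each later store overrides the earlier ones wherever its mask is true. Below is a minimal standalone model of that folding, using scalar stand-ins for VPValues and masks (an illustrative sketch, not VPlan code):

```cpp
#include <cassert>
#include <vector>

// Scalar stand-in for a predicated store recipe: a stored value and the
// lane's mask bit.
struct PredicatedStore {
  int Value;
  bool Mask;
};

// Mirrors the select-chain construction in sinkPredicatedStores: start from
// the earliest store's value and let each later store override it when its
// mask is true, i.e. select(mask, value, previous).
int foldStoredValue(const std::vector<PredicatedStore> &Group) {
  assert(!Group.empty() && "expected a non-empty group");
  int Selected = Group[0].Value;
  for (size_t I = 1; I < Group.size(); ++I)
    Selected = Group[I].Mask ? Group[I].Value : Selected;
  return Selected;
}
```

Because every group contains a complementary mask pair, each lane is written by at least one of the original stores, which is what makes replacing them with a single unconditional store legal.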

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 7 additions & 0 deletions
```diff
@@ -320,6 +320,13 @@ struct VPlanTransforms {
   static void hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
                                    const Loop *L);
 
+  /// Sink predicated stores to the same address with complementary predicates
+  /// (P and NOT P) to an unconditional store with select recipes for the
+  /// stored values. This eliminates branching overhead when all paths
+  /// unconditionally store to the same location.
+  static void sinkPredicatedStores(VPlan &Plan, ScalarEvolution &SE,
+                                   const Loop *L);
+
   // Materialize vector trip counts for constants early if it can simply be
   // computed as (Original TC / VF * UF) * VF * UF.
   static void
```
