Skip to content

Commit 84aa119

Browse files
fhahn and kcloudy0717
authored and committed
[VPlan] Sink predicated stores with complementary masks. (llvm#168771)
Extend the logic to hoist predicated loads (llvm#168373) to sink predicated stores with complementary masks in a similar fashion. The patch refactors some of the existing logic for legality checks to be shared between hosting and sinking, and adds a new sinking transform on top. With respect to the legality checks, for sinking stores the code also checks if there are any aliasing stores that may alias, not only loads. PR: llvm#168771
1 parent 451c207 commit 84aa119

File tree

4 files changed

+422
-297
lines changed

4 files changed

+422
-297
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8369,6 +8369,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
83698369
std::unique_ptr<VPlan>(VPlan0->duplicate()), SubRange, &LVer)) {
83708370
// Now optimize the initial VPlan.
83718371
VPlanTransforms::hoistPredicatedLoads(*Plan, *PSE.getSE(), OrigLoop);
8372+
VPlanTransforms::sinkPredicatedStores(*Plan, *PSE.getSE(), OrigLoop);
83728373
VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths,
83738374
*Plan, CM.getMinimalBitwidths());
83748375
VPlanTransforms::runPass(VPlanTransforms::optimize, *Plan);

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 210 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -139,35 +139,51 @@ bool VPlanTransforms::tryToConvertVPInstructionsToVPRecipes(
139139
return true;
140140
}
141141

142-
// Check if a load can be hoisted by verifying it doesn't alias with any stores
143-
// in blocks between FirstBB and LastBB using scoped noalias metadata.
144-
static bool canHoistLoadWithNoAliasCheck(VPReplicateRecipe *Load,
145-
VPBasicBlock *FirstBB,
146-
VPBasicBlock *LastBB) {
147-
// Get the load's memory location and check if it aliases with any stores
148-
// using scoped noalias metadata.
149-
auto LoadLoc = vputils::getMemoryLocation(*Load);
150-
if (!LoadLoc || !LoadLoc->AATags.Scope)
142+
// Check if a memory operation doesn't alias with memory operations in blocks
143+
// between FirstBB and LastBB using scoped noalias metadata.
144+
// For load hoisting, we only check writes in one direction.
145+
// For store sinking, we check both reads and writes bidirectionally.
146+
static bool canHoistOrSinkWithNoAliasCheck(
147+
const MemoryLocation &MemLoc, VPBasicBlock *FirstBB, VPBasicBlock *LastBB,
148+
bool CheckReads,
149+
const SmallPtrSetImpl<VPRecipeBase *> *ExcludeRecipes = nullptr) {
150+
if (!MemLoc.AATags.Scope)
151151
return false;
152152

153-
const AAMDNodes &LoadAA = LoadLoc->AATags;
153+
const AAMDNodes &MemAA = MemLoc.AATags;
154+
154155
for (VPBlockBase *Block = FirstBB; Block;
155156
Block = Block->getSingleSuccessor()) {
156-
// This function assumes a simple linear chain of blocks. If there are
157-
// multiple successors, we would need more complex analysis.
158157
assert(Block->getNumSuccessors() <= 1 &&
159158
"Expected at most one successor in block chain");
160159
auto *VPBB = cast<VPBasicBlock>(Block);
161160
for (VPRecipeBase &R : *VPBB) {
162-
if (R.mayWriteToMemory()) {
163-
auto Loc = vputils::getMemoryLocation(R);
164-
// Bail out if we can't get the location or if the scoped noalias
165-
// metadata indicates potential aliasing.
166-
if (!Loc || ScopedNoAliasAAResult::mayAliasInScopes(
167-
LoadAA.Scope, Loc->AATags.NoAlias))
168-
return false;
169-
}
161+
if (ExcludeRecipes && ExcludeRecipes->contains(&R))
162+
continue;
163+
164+
// Skip recipes that don't need checking.
165+
if (!R.mayWriteToMemory() && !(CheckReads && R.mayReadFromMemory()))
166+
continue;
167+
168+
auto Loc = vputils::getMemoryLocation(R);
169+
if (!Loc)
170+
// Conservatively assume aliasing for memory operations without
171+
// location.
172+
return false;
173+
174+
// For reads, check if they don't alias in the reverse direction and
175+
// skip if so.
176+
if (CheckReads && R.mayReadFromMemory() &&
177+
!ScopedNoAliasAAResult::mayAliasInScopes(Loc->AATags.Scope,
178+
MemAA.NoAlias))
179+
continue;
180+
181+
// Check if the memory operations may alias in the forward direction.
182+
if (ScopedNoAliasAAResult::mayAliasInScopes(MemAA.Scope,
183+
Loc->AATags.NoAlias))
184+
return false;
170185
}
186+
171187
if (Block == LastBB)
172188
break;
173189
}
@@ -4135,119 +4151,217 @@ void VPlanTransforms::hoistInvariantLoads(VPlan &Plan) {
41354151
}
41364152
}
41374153

4138-
// Returns the intersection of metadata from a group of loads.
4139-
static VPIRMetadata getCommonLoadMetadata(ArrayRef<VPReplicateRecipe *> Loads) {
4140-
VPIRMetadata CommonMetadata = *Loads.front();
4141-
for (VPReplicateRecipe *Load : drop_begin(Loads))
4142-
CommonMetadata.intersect(*Load);
4154+
// Collect common metadata from a group of replicate recipes by intersecting
4155+
// metadata from all recipes in the group.
4156+
static VPIRMetadata getCommonMetadata(ArrayRef<VPReplicateRecipe *> Recipes) {
4157+
VPIRMetadata CommonMetadata = *Recipes.front();
4158+
for (VPReplicateRecipe *Recipe : drop_begin(Recipes))
4159+
CommonMetadata.intersect(*Recipe);
41434160
return CommonMetadata;
41444161
}
41454162

4146-
void VPlanTransforms::hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
4147-
const Loop *L) {
4163+
template <unsigned Opcode>
4164+
static SmallVector<SmallVector<VPReplicateRecipe *, 4>>
4165+
collectComplementaryPredicatedMemOps(VPlan &Plan, ScalarEvolution &SE,
4166+
const Loop *L) {
4167+
static_assert(Opcode == Instruction::Load || Opcode == Instruction::Store,
4168+
"Only Load and Store opcodes supported");
4169+
constexpr bool IsLoad = (Opcode == Instruction::Load);
41484170
VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
41494171
VPTypeAnalysis TypeInfo(Plan);
4150-
VPDominatorTree VPDT(Plan);
41514172

4152-
// Group predicated loads by their address SCEV.
4153-
DenseMap<const SCEV *, SmallVector<VPReplicateRecipe *>> LoadsByAddress;
4173+
// Group predicated operations by their address SCEV.
4174+
DenseMap<const SCEV *, SmallVector<VPReplicateRecipe *>> RecipesByAddress;
41544175
for (VPBlockBase *Block : vp_depth_first_shallow(LoopRegion->getEntry())) {
41554176
auto *VPBB = cast<VPBasicBlock>(Block);
41564177
for (VPRecipeBase &R : *VPBB) {
41574178
auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
4158-
if (!RepR || RepR->getOpcode() != Instruction::Load ||
4159-
!RepR->isPredicated())
4179+
if (!RepR || RepR->getOpcode() != Opcode || !RepR->isPredicated())
41604180
continue;
41614181

4162-
VPValue *Addr = RepR->getOperand(0);
4182+
// For loads, operand 0 is address; for stores, operand 1 is address.
4183+
VPValue *Addr = RepR->getOperand(IsLoad ? 0 : 1);
41634184
const SCEV *AddrSCEV = vputils::getSCEVExprForVPValue(Addr, SE, L);
41644185
if (!isa<SCEVCouldNotCompute>(AddrSCEV))
4165-
LoadsByAddress[AddrSCEV].push_back(RepR);
4186+
RecipesByAddress[AddrSCEV].push_back(RepR);
41664187
}
41674188
}
41684189

4169-
// For each address, collect loads with complementary masks, sort by
4170-
// dominance, and use the earliest load.
4171-
for (auto &[Addr, Loads] : LoadsByAddress) {
4172-
if (Loads.size() < 2)
4190+
// For each address, collect operations with the same or complementary masks.
4191+
SmallVector<SmallVector<VPReplicateRecipe *, 4>> AllGroups;
4192+
auto GetLoadStoreValueType = [&](VPReplicateRecipe *Recipe) {
4193+
return TypeInfo.inferScalarType(IsLoad ? Recipe : Recipe->getOperand(0));
4194+
};
4195+
for (auto &[Addr, Recipes] : RecipesByAddress) {
4196+
if (Recipes.size() < 2)
41734197
continue;
41744198

4175-
// Collect groups of loads with complementary masks.
4176-
SmallVector<SmallVector<VPReplicateRecipe *, 4>> LoadGroups;
4177-
for (VPReplicateRecipe *&LoadI : Loads) {
4178-
if (!LoadI)
4199+
// Collect groups with the same or complementary masks.
4200+
for (VPReplicateRecipe *&RecipeI : Recipes) {
4201+
if (!RecipeI)
41794202
continue;
41804203

4181-
VPValue *MaskI = LoadI->getMask();
4182-
Type *TypeI = TypeInfo.inferScalarType(LoadI);
4204+
VPValue *MaskI = RecipeI->getMask();
4205+
Type *TypeI = GetLoadStoreValueType(RecipeI);
41834206
SmallVector<VPReplicateRecipe *, 4> Group;
4184-
Group.push_back(LoadI);
4185-
LoadI = nullptr;
4207+
Group.push_back(RecipeI);
4208+
RecipeI = nullptr;
41864209

4187-
// Find all loads with the same type.
4188-
for (VPReplicateRecipe *&LoadJ : Loads) {
4189-
if (!LoadJ)
4210+
// Find all operations with the same or complementary masks.
4211+
bool HasComplementaryMask = false;
4212+
for (VPReplicateRecipe *&RecipeJ : Recipes) {
4213+
if (!RecipeJ)
41904214
continue;
41914215

4192-
Type *TypeJ = TypeInfo.inferScalarType(LoadJ);
4216+
VPValue *MaskJ = RecipeJ->getMask();
4217+
Type *TypeJ = GetLoadStoreValueType(RecipeJ);
41934218
if (TypeI == TypeJ) {
4194-
Group.push_back(LoadJ);
4195-
LoadJ = nullptr;
4219+
// Check if any operation in the group has a complementary mask with
4220+
// another, that is M1 == NOT(M2) or M2 == NOT(M1).
4221+
HasComplementaryMask |= match(MaskI, m_Not(m_Specific(MaskJ))) ||
4222+
match(MaskJ, m_Not(m_Specific(MaskI)));
4223+
Group.push_back(RecipeJ);
4224+
RecipeJ = nullptr;
41964225
}
41974226
}
41984227

4199-
// Check if any load in the group has a complementary mask with another,
4200-
// that is M1 == NOT(M2) or M2 == NOT(M1).
4201-
bool HasComplementaryMask =
4202-
any_of(drop_begin(Group), [MaskI](VPReplicateRecipe *Load) {
4203-
VPValue *MaskJ = Load->getMask();
4204-
return match(MaskI, m_Not(m_Specific(MaskJ))) ||
4205-
match(MaskJ, m_Not(m_Specific(MaskI)));
4206-
});
4228+
if (HasComplementaryMask) {
4229+
assert(Group.size() >= 2 && "must have at least 2 entries");
4230+
AllGroups.push_back(std::move(Group));
4231+
}
4232+
}
4233+
}
4234+
4235+
return AllGroups;
4236+
}
4237+
4238+
// Find the recipe with minimum alignment in the group.
4239+
template <typename InstType>
4240+
static VPReplicateRecipe *
4241+
findRecipeWithMinAlign(ArrayRef<VPReplicateRecipe *> Group) {
4242+
return *min_element(Group, [](VPReplicateRecipe *A, VPReplicateRecipe *B) {
4243+
return cast<InstType>(A->getUnderlyingInstr())->getAlign() <
4244+
cast<InstType>(B->getUnderlyingInstr())->getAlign();
4245+
});
4246+
}
4247+
4248+
void VPlanTransforms::hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
4249+
const Loop *L) {
4250+
auto Groups =
4251+
collectComplementaryPredicatedMemOps<Instruction::Load>(Plan, SE, L);
4252+
if (Groups.empty())
4253+
return;
4254+
4255+
VPDominatorTree VPDT(Plan);
42074256

4208-
if (HasComplementaryMask)
4209-
LoadGroups.push_back(std::move(Group));
4257+
// Process each group of loads.
4258+
for (auto &Group : Groups) {
4259+
// Sort loads by dominance order, with earliest (most dominating) first.
4260+
sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
4261+
return VPDT.properlyDominates(A, B);
4262+
});
4263+
4264+
// Try to use the earliest (most dominating) load to replace all others.
4265+
VPReplicateRecipe *EarliestLoad = Group[0];
4266+
VPBasicBlock *FirstBB = EarliestLoad->getParent();
4267+
VPBasicBlock *LastBB = Group.back()->getParent();
4268+
4269+
// Check that the load doesn't alias with stores between first and last.
4270+
auto LoadLoc = vputils::getMemoryLocation(*EarliestLoad);
4271+
if (!LoadLoc || !canHoistOrSinkWithNoAliasCheck(*LoadLoc, FirstBB, LastBB,
4272+
/*CheckReads=*/false))
4273+
continue;
4274+
4275+
// Collect common metadata from all loads in the group.
4276+
VPIRMetadata CommonMetadata = getCommonMetadata(Group);
4277+
4278+
// Find the load with minimum alignment to use.
4279+
auto *LoadWithMinAlign = findRecipeWithMinAlign<LoadInst>(Group);
4280+
4281+
// Create an unpredicated version of the earliest load with common
4282+
// metadata.
4283+
auto *UnpredicatedLoad = new VPReplicateRecipe(
4284+
LoadWithMinAlign->getUnderlyingInstr(), {EarliestLoad->getOperand(0)},
4285+
/*IsSingleScalar=*/false, /*Mask=*/nullptr, *EarliestLoad,
4286+
CommonMetadata);
4287+
4288+
UnpredicatedLoad->insertBefore(EarliestLoad);
4289+
4290+
// Replace all loads in the group with the unpredicated load.
4291+
for (VPReplicateRecipe *Load : Group) {
4292+
Load->replaceAllUsesWith(UnpredicatedLoad);
4293+
Load->eraseFromParent();
42104294
}
4295+
}
4296+
}
42114297

4212-
// For each group, check memory dependencies and hoist the earliest load.
4213-
for (auto &Group : LoadGroups) {
4214-
// Sort loads by dominance order, with earliest (most dominating) first.
4215-
sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
4216-
return VPDT.properlyDominates(A, B);
4217-
});
4298+
static bool
4299+
canSinkStoreWithNoAliasCheck(ArrayRef<VPReplicateRecipe *> StoresToSink) {
4300+
auto StoreLoc = vputils::getMemoryLocation(*StoresToSink.front());
4301+
if (!StoreLoc || !StoreLoc->AATags.Scope)
4302+
return false;
42184303

4219-
VPReplicateRecipe *EarliestLoad = Group.front();
4220-
VPBasicBlock *FirstBB = EarliestLoad->getParent();
4221-
VPBasicBlock *LastBB = Group.back()->getParent();
4304+
// When sinking a group of stores, all members of the group alias each other.
4305+
// Skip them during the alias checks.
4306+
SmallPtrSet<VPRecipeBase *, 4> StoresToSinkSet(StoresToSink.begin(),
4307+
StoresToSink.end());
42224308

4223-
// Check that the load doesn't alias with stores between first and last.
4224-
if (!canHoistLoadWithNoAliasCheck(EarliestLoad, FirstBB, LastBB))
4225-
continue;
4309+
VPBasicBlock *FirstBB = StoresToSink.front()->getParent();
4310+
VPBasicBlock *LastBB = StoresToSink.back()->getParent();
4311+
return canHoistOrSinkWithNoAliasCheck(*StoreLoc, FirstBB, LastBB,
4312+
/*CheckReads=*/true, &StoresToSinkSet);
4313+
}
42264314

4227-
// Find the load with minimum alignment to use.
4228-
auto *LoadWithMinAlign =
4229-
*min_element(Group, [](VPReplicateRecipe *A, VPReplicateRecipe *B) {
4230-
return cast<LoadInst>(A->getUnderlyingInstr())->getAlign() <
4231-
cast<LoadInst>(B->getUnderlyingInstr())->getAlign();
4232-
});
4315+
void VPlanTransforms::sinkPredicatedStores(VPlan &Plan, ScalarEvolution &SE,
4316+
const Loop *L) {
4317+
auto Groups =
4318+
collectComplementaryPredicatedMemOps<Instruction::Store>(Plan, SE, L);
4319+
if (Groups.empty())
4320+
return;
42334321

4234-
// Collect common metadata from all loads in the group.
4235-
VPIRMetadata CommonMetadata = getCommonLoadMetadata(Group);
4236-
4237-
// Create an unpredicated load with minimum alignment using the earliest
4238-
// dominating address and common metadata.
4239-
auto *UnpredicatedLoad = new VPReplicateRecipe(
4240-
LoadWithMinAlign->getUnderlyingInstr(), EarliestLoad->getOperand(0),
4241-
/*IsSingleScalar=*/false, /*Mask=*/nullptr, /*Flags=*/{},
4242-
CommonMetadata);
4243-
UnpredicatedLoad->insertBefore(EarliestLoad);
4244-
4245-
// Replace all loads in the group with the unpredicated load.
4246-
for (VPReplicateRecipe *Load : Group) {
4247-
Load->replaceAllUsesWith(UnpredicatedLoad);
4248-
Load->eraseFromParent();
4249-
}
4322+
VPDominatorTree VPDT(Plan);
4323+
4324+
for (auto &Group : Groups) {
4325+
sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
4326+
return VPDT.properlyDominates(A, B);
4327+
});
4328+
4329+
if (!canSinkStoreWithNoAliasCheck(Group))
4330+
continue;
4331+
4332+
// Use the last (most dominated) store's location for the unconditional
4333+
// store.
4334+
VPReplicateRecipe *LastStore = Group.back();
4335+
VPBasicBlock *InsertBB = LastStore->getParent();
4336+
4337+
// Collect common alias metadata from all stores in the group.
4338+
VPIRMetadata CommonMetadata = getCommonMetadata(Group);
4339+
4340+
// Build select chain for stored values.
4341+
VPValue *SelectedValue = Group[0]->getOperand(0);
4342+
VPBuilder Builder(InsertBB, LastStore->getIterator());
4343+
4344+
for (unsigned I = 1; I < Group.size(); ++I) {
4345+
VPValue *Mask = Group[I]->getMask();
4346+
VPValue *Value = Group[I]->getOperand(0);
4347+
SelectedValue = Builder.createSelect(Mask, Value, SelectedValue,
4348+
Group[I]->getDebugLoc());
42504349
}
4350+
4351+
// Find the store with minimum alignment to use.
4352+
auto *StoreWithMinAlign = findRecipeWithMinAlign<StoreInst>(Group);
4353+
4354+
// Create unconditional store with selected value and common metadata.
4355+
auto *UnpredicatedStore =
4356+
new VPReplicateRecipe(StoreWithMinAlign->getUnderlyingInstr(),
4357+
{SelectedValue, LastStore->getOperand(1)},
4358+
/*IsSingleScalar=*/false,
4359+
/*Mask=*/nullptr, *LastStore, CommonMetadata);
4360+
UnpredicatedStore->insertBefore(*InsertBB, LastStore->getIterator());
4361+
4362+
// Remove all predicated stores from the group.
4363+
for (VPReplicateRecipe *Store : Group)
4364+
Store->eraseFromParent();
42514365
}
42524366
}
42534367

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,13 @@ struct VPlanTransforms {
325325
static void hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
326326
const Loop *L);
327327

328+
/// Sink predicated stores to the same address with complementary predicates
329+
/// (P and NOT P) to an unconditional store with select recipes for the
330+
/// stored values. This eliminates branching overhead when all paths
331+
/// unconditionally store to the same location.
332+
static void sinkPredicatedStores(VPlan &Plan, ScalarEvolution &SE,
333+
const Loop *L);
334+
328335
// Materialize vector trip counts for constants early if it can simply be
329336
// computed as (Original TC / VF * UF) * VF * UF.
330337
static void

0 commit comments

Comments
 (0)