Commit 0946ab3
[VPlan] Sink predicated stores with complementary masks. (llvm#168771)

Extend the logic to hoist predicated loads (llvm#168373) to sink predicated stores with complementary masks in a similar fashion. The patch refactors some of the existing legality checks to be shared between hoisting and sinking, and adds a new sinking transform on top. With respect to the legality checks, for sinking stores the code also checks for loads that may alias, not only stores.

PR: llvm#168771
(cherry picked from commit 4b6ad11)

1 parent cf7073b commit 0946ab3
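For illustration only (not part of the commit), here is a hypothetical scalar loop of the shape the transform targets: both branches store to the same address under complementary predicates, so the two predicated stores can be folded into one unconditional store of a selected value.

// Hypothetical source loop: A[i] is stored on both paths under complementary
// predicates (M and NOT M). All names here are made up for illustration.
void example(int *A, const int *B, int N) {
  for (int i = 0; i < N; ++i) {
    if (B[i] > 0)
      A[i] = B[i]; // predicated store under mask M
    else
      A[i] = 0;    // predicated store under mask NOT(M)
    // After sinking, each iteration effectively performs:
    //   A[i] = (B[i] > 0) ? B[i] : 0;  // single unconditional store
  }
}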

File tree

5 files changed: +360 −290 lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 0 deletions

@@ -8339,6 +8339,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
   bool HasScalarVF = Plan->hasScalarVFOnly();
   // Now optimize the initial VPlan.
   VPlanTransforms::hoistPredicatedLoads(*Plan, *PSE.getSE(), OrigLoop);
+  VPlanTransforms::sinkPredicatedStores(*Plan, *PSE.getSE(), OrigLoop);
   if (!HasScalarVF)
     VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths,
                              *Plan, CM.getMinimalBitwidths());

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 212 additions & 99 deletions

@@ -130,35 +130,51 @@ bool VPlanTransforms::tryToConvertVPInstructionsToVPRecipes(
   return true;
 }
 
-// Check if a load can be hoisted by verifying it doesn't alias with any stores
-// in blocks between FirstBB and LastBB using scoped noalias metadata.
-static bool canHoistLoadWithNoAliasCheck(VPReplicateRecipe *Load,
-                                         VPBasicBlock *FirstBB,
-                                         VPBasicBlock *LastBB) {
-  // Get the load's memory location and check if it aliases with any stores
-  // using scoped noalias metadata.
-  auto LoadLoc = vputils::getMemoryLocation(*Load);
-  if (!LoadLoc || !LoadLoc->AATags.Scope)
+// Check if a memory operation doesn't alias with memory operations in blocks
+// between FirstBB and LastBB using scoped noalias metadata.
+// For load hoisting, we only check writes in one direction.
+// For store sinking, we check both reads and writes bidirectionally.
+static bool canHoistOrSinkWithNoAliasCheck(
+    const MemoryLocation &MemLoc, VPBasicBlock *FirstBB, VPBasicBlock *LastBB,
+    bool CheckReads,
+    const SmallPtrSetImpl<VPRecipeBase *> *ExcludeRecipes = nullptr) {
+  if (!MemLoc.AATags.Scope)
     return false;
 
-  const AAMDNodes &LoadAA = LoadLoc->AATags;
+  const AAMDNodes &MemAA = MemLoc.AATags;
+
   for (VPBlockBase *Block = FirstBB; Block;
        Block = Block->getSingleSuccessor()) {
-    // This function assumes a simple linear chain of blocks. If there are
-    // multiple successors, we would need more complex analysis.
     assert(Block->getNumSuccessors() <= 1 &&
            "Expected at most one successor in block chain");
    auto *VPBB = cast<VPBasicBlock>(Block);
    for (VPRecipeBase &R : *VPBB) {
-      if (R.mayWriteToMemory()) {
-        auto Loc = vputils::getMemoryLocation(R);
-        // Bail out if we can't get the location or if the scoped noalias
-        // metadata indicates potential aliasing.
-        if (!Loc || ScopedNoAliasAAResult::mayAliasInScopes(
-                        LoadAA.Scope, Loc->AATags.NoAlias))
-          return false;
-      }
+      if (ExcludeRecipes && ExcludeRecipes->contains(&R))
+        continue;
+
+      // Skip recipes that don't need checking.
+      if (!R.mayWriteToMemory() && !(CheckReads && R.mayReadFromMemory()))
+        continue;
+
+      auto Loc = vputils::getMemoryLocation(R);
+      if (!Loc)
+        // Conservatively assume aliasing for memory operations without
+        // location.
+        return false;
+
+      // For reads, check if they don't alias in the reverse direction and
+      // skip if so.
+      if (CheckReads && R.mayReadFromMemory() &&
+          !ScopedNoAliasAAResult::mayAliasInScopes(Loc->AATags.Scope,
+                                                   MemAA.NoAlias))
+        continue;
+
+      // Check if the memory operations may alias in the forward direction.
+      if (ScopedNoAliasAAResult::mayAliasInScopes(MemAA.Scope,
+                                                  Loc->AATags.NoAlias))
+        return false;
    }
+
    if (Block == LastBB)
      break;
  }
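An aside on the CheckReads parameter above (a hypothetical illustration, not from the patch): sinking a store past a load that may alias it would change the value the load observes, which is why store sinking must check reads as well as writes, while load hoisting only needs to check intervening writes.

// Hypothetical scalar analogue: sinking the first store below the load is
// only legal if q provably does not alias p (in the VPlan transform, scoped
// noalias metadata stands in for that proof).
void sink_hazard(int *p, int *q, bool c) {
  if (c)
    *p = 1;   // candidate store under mask M
  int x = *q; // if q may alias p, this read blocks sinking the store above
  if (!c)
    *p = 2;   // complementary store under mask NOT(M)
  (void)x;
}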
@@ -3187,123 +3203,220 @@ void VPlanTransforms::hoistInvariantLoads(VPlan &Plan) {
   }
 }
 
-// Returns the intersection of metadata from a group of loads.
-static VPIRMetadata getCommonLoadMetadata(ArrayRef<VPReplicateRecipe *> Loads) {
-  VPIRMetadata CommonMetadata = *Loads.front();
-  for (VPReplicateRecipe *Load : drop_begin(Loads))
-    CommonMetadata.intersect(*Load);
+// Collect common metadata from a group of replicate recipes by intersecting
+// metadata from all recipes in the group.
+static VPIRMetadata getCommonMetadata(ArrayRef<VPReplicateRecipe *> Recipes) {
+  VPIRMetadata CommonMetadata = *Recipes.front();
+  for (VPReplicateRecipe *Recipe : drop_begin(Recipes))
+    CommonMetadata.intersect(*Recipe);
   return CommonMetadata;
 }
 
-void VPlanTransforms::hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
-                                           const Loop *L) {
-  using namespace VPlanPatternMatch;
+template <unsigned Opcode>
+static SmallVector<SmallVector<VPReplicateRecipe *, 4>>
+collectComplementaryPredicatedMemOps(VPlan &Plan, ScalarEvolution &SE,
+                                     const Loop *L) {
+  using namespace llvm::VPlanPatternMatch;
+  static_assert(Opcode == Instruction::Load || Opcode == Instruction::Store,
+                "Only Load and Store opcodes supported");
+  constexpr bool IsLoad = (Opcode == Instruction::Load);
   VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
   VPTypeAnalysis TypeInfo(Plan);
-  VPDominatorTree VPDT(Plan);
 
-  // Group predicated loads by their address SCEV.
-  DenseMap<const SCEV *, SmallVector<VPReplicateRecipe *>> LoadsByAddress;
+  // Group predicated operations by their address SCEV.
+  DenseMap<const SCEV *, SmallVector<VPReplicateRecipe *>> RecipesByAddress;
   for (VPBlockBase *Block : vp_depth_first_shallow(LoopRegion->getEntry())) {
     auto *VPBB = cast<VPBasicBlock>(Block);
     for (VPRecipeBase &R : *VPBB) {
       auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
-      if (!RepR || RepR->getOpcode() != Instruction::Load ||
-          !RepR->isPredicated())
+      if (!RepR || RepR->getOpcode() != Opcode || !RepR->isPredicated())
         continue;
 
-      VPValue *Addr = RepR->getOperand(0);
+      // For loads, operand 0 is address; for stores, operand 1 is address.
+      VPValue *Addr = RepR->getOperand(IsLoad ? 0 : 1);
       const SCEV *AddrSCEV = vputils::getSCEVExprForVPValue(Addr, SE, L);
       if (!isa<SCEVCouldNotCompute>(AddrSCEV))
-        LoadsByAddress[AddrSCEV].push_back(RepR);
+        RecipesByAddress[AddrSCEV].push_back(RepR);
     }
   }
 
-  // For each address, collect loads with complementary masks, sort by
-  // dominance, and use the earliest load.
-  for (auto &[Addr, Loads] : LoadsByAddress) {
-    if (Loads.size() < 2)
+  // For each address, collect operations with the same or complementary masks.
+  SmallVector<SmallVector<VPReplicateRecipe *, 4>> AllGroups;
+  auto GetLoadStoreValueType = [&](VPReplicateRecipe *Recipe) {
+    return TypeInfo.inferScalarType(IsLoad ? Recipe : Recipe->getOperand(0));
+  };
+  for (auto &[Addr, Recipes] : RecipesByAddress) {
+    if (Recipes.size() < 2)
       continue;
 
-    // Collect groups of loads with complementary masks.
-    SmallVector<SmallVector<VPReplicateRecipe *, 4>> LoadGroups;
-    for (VPReplicateRecipe *&LoadI : Loads) {
-      if (!LoadI)
+    // Collect groups with the same or complementary masks.
+    for (VPReplicateRecipe *&RecipeI : Recipes) {
+      if (!RecipeI)
        continue;
 
-      VPValue *MaskI = LoadI->getMask();
-      Type *TypeI = TypeInfo.inferScalarType(LoadI);
+      VPValue *MaskI = RecipeI->getMask();
+      Type *TypeI = GetLoadStoreValueType(RecipeI);
       SmallVector<VPReplicateRecipe *, 4> Group;
-      Group.push_back(LoadI);
-      LoadI = nullptr;
+      Group.push_back(RecipeI);
+      RecipeI = nullptr;
 
-      // Find all loads with the same type.
-      for (VPReplicateRecipe *&LoadJ : Loads) {
-        if (!LoadJ)
+      // Find all operations with the same or complementary masks.
+      bool HasComplementaryMask = false;
+      for (VPReplicateRecipe *&RecipeJ : Recipes) {
+        if (!RecipeJ)
          continue;
 
-        Type *TypeJ = TypeInfo.inferScalarType(LoadJ);
+        VPValue *MaskJ = RecipeJ->getMask();
+        Type *TypeJ = GetLoadStoreValueType(RecipeJ);
         if (TypeI == TypeJ) {
-          Group.push_back(LoadJ);
-          LoadJ = nullptr;
+          // Check if any operation in the group has a complementary mask with
+          // another, that is M1 == NOT(M2) or M2 == NOT(M1).
+          HasComplementaryMask |= match(MaskI, m_Not(m_Specific(MaskJ))) ||
+                                  match(MaskJ, m_Not(m_Specific(MaskI)));
+          Group.push_back(RecipeJ);
+          RecipeJ = nullptr;
         }
       }
 
-      // Check if any load in the group has a complementary mask with another,
-      // that is M1 == NOT(M2) or M2 == NOT(M1).
-      bool HasComplementaryMask =
-          any_of(drop_begin(Group), [MaskI](VPReplicateRecipe *Load) {
-            VPValue *MaskJ = Load->getMask();
-            return match(MaskI, m_Not(m_Specific(MaskJ))) ||
-                   match(MaskJ, m_Not(m_Specific(MaskI)));
-          });
-
-      if (HasComplementaryMask)
-        LoadGroups.push_back(std::move(Group));
+      if (HasComplementaryMask) {
+        assert(Group.size() >= 2 && "must have at least 2 entries");
+        AllGroups.push_back(std::move(Group));
+      }
     }
+  }
 
-    // For each group, check memory dependencies and hoist the earliest load.
-    for (auto &Group : LoadGroups) {
-      // Sort loads by dominance order, with earliest (most dominating) first.
-      sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
-        return VPDT.properlyDominates(A, B);
-      });
+  return AllGroups;
+}
 
-      VPReplicateRecipe *EarliestLoad = Group.front();
-      VPBasicBlock *FirstBB = EarliestLoad->getParent();
-      VPBasicBlock *LastBB = Group.back()->getParent();
+// Find the recipe with minimum alignment in the group.
+template <typename InstType>
+static VPReplicateRecipe *
+findRecipeWithMinAlign(ArrayRef<VPReplicateRecipe *> Group) {
+  return *min_element(Group, [](VPReplicateRecipe *A, VPReplicateRecipe *B) {
+    return cast<InstType>(A->getUnderlyingInstr())->getAlign() <
+           cast<InstType>(B->getUnderlyingInstr())->getAlign();
+  });
+}
 
-      // Check that the load doesn't alias with stores between first and last.
-      if (!canHoistLoadWithNoAliasCheck(EarliestLoad, FirstBB, LastBB))
-        continue;
+void VPlanTransforms::hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
+                                           const Loop *L) {
+  auto Groups =
+      collectComplementaryPredicatedMemOps<Instruction::Load>(Plan, SE, L);
+  if (Groups.empty())
+    return;
 
-      // Find the load with minimum alignment to use.
-      auto *LoadWithMinAlign =
-          *min_element(Group, [](VPReplicateRecipe *A, VPReplicateRecipe *B) {
-            return cast<LoadInst>(A->getUnderlyingInstr())->getAlign() <
-                   cast<LoadInst>(B->getUnderlyingInstr())->getAlign();
-          });
-
-      // Collect common metadata from all loads in the group.
-      VPIRMetadata CommonMetadata = getCommonLoadMetadata(Group);
-
-      // Create an unpredicated load with minimum alignment using the earliest
-      // dominating address and common metadata.
-      auto *UnpredicatedLoad = new VPReplicateRecipe(
-          LoadWithMinAlign->getUnderlyingInstr(), EarliestLoad->getOperand(0),
-          /*IsSingleScalar=*/false, /*Mask=*/nullptr,
-          CommonMetadata);
-      UnpredicatedLoad->insertBefore(EarliestLoad);
-
-      // Replace all loads in the group with the unpredicated load.
-      for (VPReplicateRecipe *Load : Group) {
-        Load->replaceAllUsesWith(UnpredicatedLoad);
-        Load->eraseFromParent();
-      }
+  VPDominatorTree VPDT(Plan);
+
+  // Process each group of loads.
+  for (auto &Group : Groups) {
+    // Sort loads by dominance order, with earliest (most dominating) first.
+    sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
+      return VPDT.properlyDominates(A, B);
+    });
+
+    // Try to use the earliest (most dominating) load to replace all others.
+    VPReplicateRecipe *EarliestLoad = Group[0];
+    VPBasicBlock *FirstBB = EarliestLoad->getParent();
+    VPBasicBlock *LastBB = Group.back()->getParent();
+
+    // Check that the load doesn't alias with stores between first and last.
+    auto LoadLoc = vputils::getMemoryLocation(*EarliestLoad);
+    if (!LoadLoc || !canHoistOrSinkWithNoAliasCheck(*LoadLoc, FirstBB, LastBB,
+                                                    /*CheckReads=*/false))
+      continue;
+
+    // Collect common metadata from all loads in the group.
+    VPIRMetadata CommonMetadata = getCommonMetadata(Group);
+
+    // Find the load with minimum alignment to use.
+    auto *LoadWithMinAlign = findRecipeWithMinAlign<LoadInst>(Group);
+
+    // Create an unpredicated version of the earliest load with common
+    // metadata.
+    auto *UnpredicatedLoad = new VPReplicateRecipe(
+        LoadWithMinAlign->getUnderlyingInstr(), {EarliestLoad->getOperand(0)},
+        /*IsSingleScalar=*/false, /*Mask=*/nullptr, CommonMetadata);
+
+    UnpredicatedLoad->insertBefore(EarliestLoad);
+
+    // Replace all loads in the group with the unpredicated load.
+    for (VPReplicateRecipe *Load : Group) {
+      Load->replaceAllUsesWith(UnpredicatedLoad);
+      Load->eraseFromParent();
    }
  }
 }
 
+static bool
+canSinkStoreWithNoAliasCheck(ArrayRef<VPReplicateRecipe *> StoresToSink) {
+  auto StoreLoc = vputils::getMemoryLocation(*StoresToSink.front());
+  if (!StoreLoc || !StoreLoc->AATags.Scope)
+    return false;
+
+  // When sinking a group of stores, all members of the group alias each other.
+  // Skip them during the alias checks.
+  SmallPtrSet<VPRecipeBase *, 4> StoresToSinkSet(StoresToSink.begin(),
+                                                 StoresToSink.end());
+
+  VPBasicBlock *FirstBB = StoresToSink.front()->getParent();
+  VPBasicBlock *LastBB = StoresToSink.back()->getParent();
+  return canHoistOrSinkWithNoAliasCheck(*StoreLoc, FirstBB, LastBB,
+                                        /*CheckReads=*/true, &StoresToSinkSet);
+}
+
+void VPlanTransforms::sinkPredicatedStores(VPlan &Plan, ScalarEvolution &SE,
+                                           const Loop *L) {
+  auto Groups =
+      collectComplementaryPredicatedMemOps<Instruction::Store>(Plan, SE, L);
+  if (Groups.empty())
+    return;
+
+  VPDominatorTree VPDT(Plan);
+
+  for (auto &Group : Groups) {
+    sort(Group, [&VPDT](VPReplicateRecipe *A, VPReplicateRecipe *B) {
+      return VPDT.properlyDominates(A, B);
+    });
+
+    if (!canSinkStoreWithNoAliasCheck(Group))
+      continue;
+
+    // Use the last (most dominated) store's location for the unconditional
+    // store.
+    VPReplicateRecipe *LastStore = Group.back();
+    VPBasicBlock *InsertBB = LastStore->getParent();
+
+    // Collect common alias metadata from all stores in the group.
+    VPIRMetadata CommonMetadata = getCommonMetadata(Group);
+
+    // Build select chain for stored values.
+    VPValue *SelectedValue = Group[0]->getOperand(0);
+    VPBuilder Builder(InsertBB, LastStore->getIterator());
+
+    for (unsigned I = 1; I < Group.size(); ++I) {
+      VPValue *Mask = Group[I]->getMask();
+      VPValue *Value = Group[I]->getOperand(0);
+      SelectedValue = Builder.createSelect(Mask, Value, SelectedValue,
+                                           Group[I]->getDebugLoc());
+    }
+
+    // Find the store with minimum alignment to use.
+    auto *StoreWithMinAlign = findRecipeWithMinAlign<StoreInst>(Group);
+
+    // Create unconditional store with selected value and common metadata.
+    auto *UnpredicatedStore =
+        new VPReplicateRecipe(StoreWithMinAlign->getUnderlyingInstr(),
+                              {SelectedValue, LastStore->getOperand(1)},
+                              /*IsSingleScalar=*/false,
+                              /*Mask=*/nullptr, CommonMetadata);
+    UnpredicatedStore->insertBefore(*InsertBB, LastStore->getIterator());
+
+    // Remove all predicated stores from the group.
+    for (VPReplicateRecipe *Store : Group)
+      Store->eraseFromParent();
+  }
+}
+
 /// Returns true if \p V is VPWidenLoadRecipe or VPInterleaveRecipe that can be
 /// converted to a narrower recipe. \p V is used by a wide recipe that feeds a
 /// store interleave group at index \p Idx, \p WideMember0 is the recipe feeding
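As a rough scalar model of the select chain built in sinkPredicatedStores (a standalone sketch, not LLVM code, assuming dominance-sorted inputs with scalar bool masks standing in for vector masks): each later store's mask selects its value over the accumulated one, so where masks overlap the most dominated store wins, preserving the original store order.

#include <utility>
#include <vector>

// Fold dominance-ordered (mask, value) pairs into the single value the
// unconditional store writes; mirrors Builder.createSelect(Mask, Value,
// SelectedValue) applied left to right.
int foldStoredValue(const std::vector<std::pair<bool, int>> &MaskedValues) {
  int Selected = MaskedValues.front().second;
  for (std::size_t I = 1; I < MaskedValues.size(); ++I)
    if (MaskedValues[I].first) // select keeps the newer value where mask is true
      Selected = MaskedValues[I].second;
  return Selected;
}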

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 7 additions & 0 deletions

@@ -241,6 +241,13 @@ struct VPlanTransforms {
   static void hoistPredicatedLoads(VPlan &Plan, ScalarEvolution &SE,
                                    const Loop *L);
 
+  /// Sink predicated stores to the same address with complementary predicates
+  /// (P and NOT P) to an unconditional store with select recipes for the
+  /// stored values. This eliminates branching overhead when all paths
+  /// unconditionally store to the same location.
+  static void sinkPredicatedStores(VPlan &Plan, ScalarEvolution &SE,
+                                   const Loop *L);
+
   /// Try to convert a plan with interleave groups with VF elements to a plan
   /// with the interleave groups replaced by wide loads and stores processing VF
   /// elements, if all transformed interleave groups access the full vector
