Skip to content

Commit e2bbd6d

Browse files
PeddleSpam and Leon Clark authored
[VectorCombine][AMDGPU] Narrow Phi of Shuffles. (#140188)
Attempt to narrow a phi of shufflevector instructions where the two incoming values have the same operands but different masks. Related to #128938. --------- Co-authored-by: Leon Clark <[email protected]>
1 parent c1f797e commit e2bbd6d

File tree

8 files changed

+9204
-0
lines changed

8 files changed

+9204
-0
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class VectorCombine {
142142
bool foldInterleaveIntrinsics(Instruction &I);
143143
bool shrinkType(Instruction &I);
144144
bool shrinkLoadForShuffles(Instruction &I);
145+
bool shrinkPhiOfShuffles(Instruction &I);
145146

146147
void replaceValue(Value &Old, Value &New) {
147148
LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
@@ -3994,6 +3995,101 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
39943995
return false;
39953996
}
39963997

3998+
// Attempt to narrow a phi of shufflevector instructions where the two incoming
3999+
// values have the same operands but different masks. If the two shuffle masks
4000+
// are offsets of one another we can use one branch to rotate the incoming
4001+
// vector and perform one larger shuffle after the phi.
4002+
bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
4003+
auto *Phi = dyn_cast<PHINode>(&I);
4004+
if (!Phi || Phi->getNumIncomingValues() != 2u)
4005+
return false;
4006+
4007+
Value *Op = nullptr;
4008+
ArrayRef<int> Mask0;
4009+
ArrayRef<int> Mask1;
4010+
4011+
if (!match(Phi->getOperand(0u),
4012+
m_OneUse(m_Shuffle(m_Value(Op), m_Poison(), m_Mask(Mask0)))) ||
4013+
!match(Phi->getOperand(1u),
4014+
m_OneUse(m_Shuffle(m_Specific(Op), m_Poison(), m_Mask(Mask1)))))
4015+
return false;
4016+
4017+
auto *Shuf = cast<ShuffleVectorInst>(Phi->getOperand(0u));
4018+
4019+
// Ensure result vectors are wider than the argument vector.
4020+
auto *InputVT = cast<FixedVectorType>(Op->getType());
4021+
auto *ResultVT = cast<FixedVectorType>(Shuf->getType());
4022+
auto const InputNumElements = InputVT->getNumElements();
4023+
4024+
if (InputNumElements >= ResultVT->getNumElements())
4025+
return false;
4026+
4027+
// Take the difference of the two shuffle masks at each index. Ignore poison
4028+
// values at the same index in both masks.
4029+
SmallVector<int, 16> NewMask;
4030+
NewMask.reserve(Mask0.size());
4031+
4032+
for (auto [M0, M1] : zip(Mask0, Mask1)) {
4033+
if (M0 >= 0 && M1 >= 0)
4034+
NewMask.push_back(M0 - M1);
4035+
else if (M0 == -1 && M1 == -1)
4036+
continue;
4037+
else
4038+
return false;
4039+
}
4040+
4041+
// Ensure all elements of the new mask are equal. If the difference between
4042+
// the incoming mask elements is the same, the two must be constant offsets
4043+
// of one another.
4044+
if (NewMask.empty() || !all_equal(NewMask))
4045+
return false;
4046+
4047+
// Create new mask using difference of the two incoming masks.
4048+
int MaskOffset = NewMask[0u];
4049+
unsigned Index = (InputNumElements - MaskOffset) % InputNumElements;
4050+
NewMask.clear();
4051+
4052+
for (unsigned I = 0u; I < InputNumElements; ++I) {
4053+
NewMask.push_back(Index);
4054+
Index = (Index + 1u) % InputNumElements;
4055+
}
4056+
4057+
// Calculate costs for worst cases and compare.
4058+
auto const Kind = TTI::SK_PermuteSingleSrc;
4059+
auto OldCost =
4060+
std::max(TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask0, CostKind),
4061+
TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind));
4062+
auto NewCost = TTI.getShuffleCost(Kind, InputVT, InputVT, NewMask, CostKind) +
4063+
TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind);
4064+
4065+
LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
4066+
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
4067+
<< "\n");
4068+
4069+
if (NewCost > OldCost)
4070+
return false;
4071+
4072+
// Create new shuffles and narrowed phi.
4073+
auto Builder = IRBuilder(Shuf);
4074+
Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
4075+
auto *PoisonVal = PoisonValue::get(InputVT);
4076+
auto *NewShuf0 = Builder.CreateShuffleVector(Op, PoisonVal, NewMask);
4077+
Worklist.push(cast<Instruction>(NewShuf0));
4078+
4079+
Builder.SetInsertPoint(Phi);
4080+
Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
4081+
auto *NewPhi = Builder.CreatePHI(NewShuf0->getType(), 2u);
4082+
NewPhi->addIncoming(NewShuf0, Phi->getIncomingBlock(0u));
4083+
NewPhi->addIncoming(Op, Phi->getIncomingBlock(1u));
4084+
4085+
Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
4086+
PoisonVal = PoisonValue::get(NewPhi->getType());
4087+
auto *NewShuf1 = Builder.CreateShuffleVector(NewPhi, PoisonVal, Mask1);
4088+
4089+
replaceValue(*Phi, *NewShuf1);
4090+
return true;
4091+
}
4092+
39974093
/// This is the entry point for all transforms. Pass manager differences are
39984094
/// handled in the callers of this function.
39994095
bool VectorCombine::run() {
@@ -4081,6 +4177,9 @@ bool VectorCombine::run() {
40814177
case Instruction::Xor:
40824178
MadeChange |= foldBitOpOfCastops(I);
40834179
break;
4180+
case Instruction::PHI:
4181+
MadeChange |= shrinkPhiOfShuffles(I);
4182+
break;
40844183
default:
40854184
MadeChange |= shrinkType(I);
40864185
break;

0 commit comments

Comments
 (0)