Skip to content

Commit 766c90f

Browse files
authored
[VectorCombine] foldShuffleOfCastops - handle unary shuffles (#160009)
Fixes #156853.
1 parent 31818fb commit 766c90f

File tree

4 files changed, with 177 additions and 98 deletions.

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 46 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -2487,21 +2487,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
24872487
if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
24882488
return false;
24892489

2490+
// Check whether this is a binary shuffle.
2491+
bool IsBinaryShuffle = !isa<UndefValue>(V1);
2492+
24902493
auto *C0 = dyn_cast<CastInst>(V0);
24912494
auto *C1 = dyn_cast<CastInst>(V1);
2492-
if (!C0 || !C1)
2495+
if (!C0 || (IsBinaryShuffle && !C1))
24932496
return false;
24942497

24952498
Instruction::CastOps Opcode = C0->getOpcode();
2496-
if (C0->getSrcTy() != C1->getSrcTy())
2499+
2500+
// If this is allowed, foldShuffleOfCastops can get stuck in a loop
2501+
// with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2502+
if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
24972503
return false;
24982504

2499-
// Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2500-
if (Opcode != C1->getOpcode()) {
2501-
if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2502-
Opcode = Instruction::SExt;
2503-
else
2505+
if (IsBinaryShuffle) {
2506+
if (C0->getSrcTy() != C1->getSrcTy())
25042507
return false;
2508+
// Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2509+
if (Opcode != C1->getOpcode()) {
2510+
if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2511+
Opcode = Instruction::SExt;
2512+
else
2513+
return false;
2514+
}
25052515
}
25062516

25072517
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -2544,38 +2554,52 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
25442554
InstructionCost CostC0 =
25452555
TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
25462556
TTI::CastContextHint::None, CostKind);
2547-
InstructionCost CostC1 =
2548-
TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2549-
TTI::CastContextHint::None, CostKind);
2550-
InstructionCost OldCost = CostC0 + CostC1;
2551-
OldCost +=
2552-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2553-
CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
25542557

2555-
InstructionCost NewCost =
2556-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
2557-
CastSrcTy, NewMask, CostKind);
2558+
TargetTransformInfo::ShuffleKind ShuffleKind;
2559+
if (IsBinaryShuffle)
2560+
ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2561+
else
2562+
ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2563+
2564+
InstructionCost OldCost = CostC0;
2565+
OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2566+
CostKind, 0, nullptr, {}, &I);
2567+
2568+
InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
2569+
CastSrcTy, NewMask, CostKind);
25582570
NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
25592571
TTI::CastContextHint::None, CostKind);
25602572
if (!C0->hasOneUse())
25612573
NewCost += CostC0;
2562-
if (!C1->hasOneUse())
2563-
NewCost += CostC1;
2574+
if (IsBinaryShuffle) {
2575+
InstructionCost CostC1 =
2576+
TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2577+
TTI::CastContextHint::None, CostKind);
2578+
OldCost += CostC1;
2579+
if (!C1->hasOneUse())
2580+
NewCost += CostC1;
2581+
}
25642582

25652583
LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
25662584
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
25672585
<< "\n");
25682586
if (NewCost > OldCost)
25692587
return false;
25702588

2571-
Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
2572-
C1->getOperand(0), NewMask);
2589+
Value *Shuf;
2590+
if (IsBinaryShuffle)
2591+
Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
2592+
NewMask);
2593+
else
2594+
Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
2595+
25732596
Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
25742597

25752598
// Intersect flags from the old casts.
25762599
if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
25772600
NewInst->copyIRFlags(C0);
2578-
NewInst->andIRFlags(C1);
2601+
if (IsBinaryShuffle)
2602+
NewInst->andIRFlags(C1);
25792603
}
25802604

25812605
Worklist.pushValue(Shuf);

0 commit comments

Comments
 (0)