Skip to content

Commit ed4c1df

Browse files
committed
[VectorCombine] foldShuffleOfCastops - handle unary shuffles
1 parent 96d5567 commit ed4c1df

File tree

1 file changed

+44
-23
lines changed

1 file changed

+44
-23
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 44 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -2477,21 +2477,28 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
24772477
if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
24782478
return false;
24792479

2480+
// Check whether this is a unary shuffle.
2481+
// TODO: should this be extended to match undef or unused values?
2482+
bool IsBinaryShuffle = !isa<PoisonValue>(V1);
2483+
LLVM_DEBUG(dbgs() << "Is binary shuffle: " << IsBinaryShuffle << "\n");
2484+
24802485
auto *C0 = dyn_cast<CastInst>(V0);
24812486
auto *C1 = dyn_cast<CastInst>(V1);
2482-
if (!C0 || !C1)
2487+
if (!C0 || (IsBinaryShuffle && !C1))
24832488
return false;
24842489

24852490
Instruction::CastOps Opcode = C0->getOpcode();
2486-
if (C0->getSrcTy() != C1->getSrcTy())
2487-
return false;
24882491

2489-
// Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2490-
if (Opcode != C1->getOpcode()) {
2491-
if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2492-
Opcode = Instruction::SExt;
2493-
else
2492+
if (IsBinaryShuffle) {
2493+
if (C0->getSrcTy() != C1->getSrcTy())
24942494
return false;
2495+
// Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2496+
if (Opcode != C1->getOpcode()) {
2497+
if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2498+
Opcode = Instruction::SExt;
2499+
else
2500+
return false;
2501+
}
24952502
}
24962503

24972504
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -2534,38 +2541,52 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
25342541
InstructionCost CostC0 =
25352542
TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
25362543
TTI::CastContextHint::None, CostKind);
2537-
InstructionCost CostC1 =
2538-
TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2539-
TTI::CastContextHint::None, CostKind);
2540-
InstructionCost OldCost = CostC0 + CostC1;
2541-
OldCost +=
2542-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2543-
CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
25442544

2545-
InstructionCost NewCost =
2546-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
2547-
CastSrcTy, NewMask, CostKind);
2545+
TargetTransformInfo::ShuffleKind ShuffleKind;
2546+
if (IsBinaryShuffle)
2547+
ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2548+
else
2549+
ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2550+
2551+
InstructionCost OldCost = CostC0;
2552+
OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2553+
CostKind, 0, nullptr, {}, &I);
2554+
2555+
InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
2556+
CastSrcTy, NewMask, CostKind);
25482557
NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
25492558
TTI::CastContextHint::None, CostKind);
25502559
if (!C0->hasOneUse())
25512560
NewCost += CostC0;
2552-
if (!C1->hasOneUse())
2553-
NewCost += CostC1;
2561+
if (IsBinaryShuffle) {
2562+
InstructionCost CostC1 =
2563+
TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2564+
TTI::CastContextHint::None, CostKind);
2565+
OldCost += CostC1;
2566+
if (!C1->hasOneUse())
2567+
NewCost += CostC1;
2568+
}
25542569

25552570
LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
25562571
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
25572572
<< "\n");
25582573
if (NewCost > OldCost)
25592574
return false;
25602575

2561-
Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
2562-
C1->getOperand(0), NewMask);
2576+
Value *Shuf;
2577+
if (IsBinaryShuffle)
2578+
Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
2579+
NewMask);
2580+
else
2581+
Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
2582+
25632583
Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
25642584

25652585
// Intersect flags from the old casts.
25662586
if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
25672587
NewInst->copyIRFlags(C0);
2568-
NewInst->andIRFlags(C1);
2588+
if (IsBinaryShuffle)
2589+
NewInst->andIRFlags(C1);
25692590
}
25702591

25712592
Worklist.pushValue(Shuf);

0 commit comments

Comments
 (0)