@@ -2477,21 +2477,28 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2477
2477
if (!match (&I, m_Shuffle (m_Value (V0), m_Value (V1), m_Mask (OldMask))))
2478
2478
return false ;
2479
2479
2480
+ // Check whether this is a unary shuffle.
2481
+ // TODO: Should this be extended to match undef or unused values?
2482
+ bool IsBinaryShuffle = !isa<PoisonValue>(V1);
2483
+ LLVM_DEBUG (dbgs () << " Is binary shuffle: " << IsBinaryShuffle << " \n " );
2484
+
2480
2485
auto *C0 = dyn_cast<CastInst>(V0);
2481
2486
auto *C1 = dyn_cast<CastInst>(V1);
2482
- if (!C0 || !C1)
2487
+ if (!C0 || (IsBinaryShuffle && !C1) )
2483
2488
return false ;
2484
2489
2485
2490
Instruction::CastOps Opcode = C0->getOpcode ();
2486
- if (C0->getSrcTy () != C1->getSrcTy ())
2487
- return false ;
2488
2491
2489
- // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2490
- if (Opcode != C1->getOpcode ()) {
2491
- if (match (C0, m_SExtLike (m_Value ())) && match (C1, m_SExtLike (m_Value ())))
2492
- Opcode = Instruction::SExt;
2493
- else
2492
+ if (IsBinaryShuffle) {
2493
+ if (C0->getSrcTy () != C1->getSrcTy ())
2494
2494
return false ;
2495
+ // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2496
+ if (Opcode != C1->getOpcode ()) {
2497
+ if (match (C0, m_SExtLike (m_Value ())) && match (C1, m_SExtLike (m_Value ())))
2498
+ Opcode = Instruction::SExt;
2499
+ else
2500
+ return false ;
2501
+ }
2495
2502
}
2496
2503
2497
2504
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
@@ -2534,38 +2541,52 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2534
2541
InstructionCost CostC0 =
2535
2542
TTI.getCastInstrCost (C0->getOpcode (), CastDstTy, CastSrcTy,
2536
2543
TTI::CastContextHint::None, CostKind);
2537
- InstructionCost CostC1 =
2538
- TTI.getCastInstrCost (C1->getOpcode (), CastDstTy, CastSrcTy,
2539
- TTI::CastContextHint::None, CostKind);
2540
- InstructionCost OldCost = CostC0 + CostC1;
2541
- OldCost +=
2542
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2543
- CastDstTy, OldMask, CostKind, 0 , nullptr , {}, &I);
2544
2544
2545
- InstructionCost NewCost =
2546
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
2547
- CastSrcTy, NewMask, CostKind);
2545
+ TargetTransformInfo::ShuffleKind ShuffleKind;
2546
+ if (IsBinaryShuffle)
2547
+ ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2548
+ else
2549
+ ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2550
+
2551
+ InstructionCost OldCost = CostC0;
2552
+ OldCost += TTI.getShuffleCost (ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2553
+ CostKind, 0 , nullptr , {}, &I);
2554
+
2555
+ InstructionCost NewCost = TTI.getShuffleCost (ShuffleKind, NewShuffleDstTy,
2556
+ CastSrcTy, NewMask, CostKind);
2548
2557
NewCost += TTI.getCastInstrCost (Opcode, ShuffleDstTy, NewShuffleDstTy,
2549
2558
TTI::CastContextHint::None, CostKind);
2550
2559
if (!C0->hasOneUse ())
2551
2560
NewCost += CostC0;
2552
- if (!C1->hasOneUse ())
2553
- NewCost += CostC1;
2561
+ if (IsBinaryShuffle) {
2562
+ InstructionCost CostC1 =
2563
+ TTI.getCastInstrCost (C1->getOpcode (), CastDstTy, CastSrcTy,
2564
+ TTI::CastContextHint::None, CostKind);
2565
+ OldCost += CostC1;
2566
+ if (!C1->hasOneUse ())
2567
+ NewCost += CostC1;
2568
+ }
2554
2569
2555
2570
LLVM_DEBUG (dbgs () << " Found a shuffle feeding two casts: " << I
2556
2571
<< " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
2557
2572
<< " \n " );
2558
2573
if (NewCost > OldCost)
2559
2574
return false ;
2560
2575
2561
- Value *Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ),
2562
- C1->getOperand (0 ), NewMask);
2576
+ Value *Shuf;
2577
+ if (IsBinaryShuffle)
2578
+ Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ), C1->getOperand (0 ),
2579
+ NewMask);
2580
+ else
2581
+ Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ), NewMask);
2582
+
2563
2583
Value *Cast = Builder.CreateCast (Opcode, Shuf, ShuffleDstTy);
2564
2584
2565
2585
// Intersect flags from the old casts.
2566
2586
if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
2567
2587
NewInst->copyIRFlags (C0);
2568
- NewInst->andIRFlags (C1);
2588
+ if (IsBinaryShuffle)
2589
+ NewInst->andIRFlags (C1);
2569
2590
}
2570
2591
2571
2592
Worklist.pushValue (Shuf);
0 commit comments