@@ -2487,21 +2487,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2487
2487
if (!match (&I, m_Shuffle (m_Value (V0), m_Value (V1), m_Mask (OldMask))))
2488
2488
return false ;
2489
2489
2490
+ // Check whether this is a binary shuffle.
2491
+ bool IsBinaryShuffle = !isa<UndefValue>(V1);
2492
+
2490
2493
auto *C0 = dyn_cast<CastInst>(V0);
2491
2494
auto *C1 = dyn_cast<CastInst>(V1);
2492
- if (!C0 || !C1)
2495
+ if (!C0 || (IsBinaryShuffle && !C1) )
2493
2496
return false ;
2494
2497
2495
2498
Instruction::CastOps Opcode = C0->getOpcode ();
2496
- if (C0->getSrcTy () != C1->getSrcTy ())
2499
+
2500
+ // If this is allowed, foldShuffleOfCastops can get stuck in a loop
2501
+ // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2502
+ if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
2497
2503
return false ;
2498
2504
2499
- // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2500
- if (Opcode != C1->getOpcode ()) {
2501
- if (match (C0, m_SExtLike (m_Value ())) && match (C1, m_SExtLike (m_Value ())))
2502
- Opcode = Instruction::SExt;
2503
- else
2505
+ if (IsBinaryShuffle) {
2506
+ if (C0->getSrcTy () != C1->getSrcTy ())
2504
2507
return false ;
2508
+ // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2509
+ if (Opcode != C1->getOpcode ()) {
2510
+ if (match (C0, m_SExtLike (m_Value ())) && match (C1, m_SExtLike (m_Value ())))
2511
+ Opcode = Instruction::SExt;
2512
+ else
2513
+ return false ;
2514
+ }
2505
2515
}
2506
2516
2507
2517
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
@@ -2544,38 +2554,52 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2544
2554
InstructionCost CostC0 =
2545
2555
TTI.getCastInstrCost (C0->getOpcode (), CastDstTy, CastSrcTy,
2546
2556
TTI::CastContextHint::None, CostKind);
2547
- InstructionCost CostC1 =
2548
- TTI.getCastInstrCost (C1->getOpcode (), CastDstTy, CastSrcTy,
2549
- TTI::CastContextHint::None, CostKind);
2550
- InstructionCost OldCost = CostC0 + CostC1;
2551
- OldCost +=
2552
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2553
- CastDstTy, OldMask, CostKind, 0 , nullptr , {}, &I);
2554
2557
2555
- InstructionCost NewCost =
2556
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
2557
- CastSrcTy, NewMask, CostKind);
2558
+ TargetTransformInfo::ShuffleKind ShuffleKind;
2559
+ if (IsBinaryShuffle)
2560
+ ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2561
+ else
2562
+ ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2563
+
2564
+ InstructionCost OldCost = CostC0;
2565
+ OldCost += TTI.getShuffleCost (ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2566
+ CostKind, 0 , nullptr , {}, &I);
2567
+
2568
+ InstructionCost NewCost = TTI.getShuffleCost (ShuffleKind, NewShuffleDstTy,
2569
+ CastSrcTy, NewMask, CostKind);
2558
2570
NewCost += TTI.getCastInstrCost (Opcode, ShuffleDstTy, NewShuffleDstTy,
2559
2571
TTI::CastContextHint::None, CostKind);
2560
2572
if (!C0->hasOneUse ())
2561
2573
NewCost += CostC0;
2562
- if (!C1->hasOneUse ())
2563
- NewCost += CostC1;
2574
+ if (IsBinaryShuffle) {
2575
+ InstructionCost CostC1 =
2576
+ TTI.getCastInstrCost (C1->getOpcode (), CastDstTy, CastSrcTy,
2577
+ TTI::CastContextHint::None, CostKind);
2578
+ OldCost += CostC1;
2579
+ if (!C1->hasOneUse ())
2580
+ NewCost += CostC1;
2581
+ }
2564
2582
2565
2583
LLVM_DEBUG (dbgs () << " Found a shuffle feeding two casts: " << I
2566
2584
<< " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
2567
2585
<< " \n " );
2568
2586
if (NewCost > OldCost)
2569
2587
return false ;
2570
2588
2571
- Value *Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ),
2572
- C1->getOperand (0 ), NewMask);
2589
+ Value *Shuf;
2590
+ if (IsBinaryShuffle)
2591
+ Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ), C1->getOperand (0 ),
2592
+ NewMask);
2593
+ else
2594
+ Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ), NewMask);
2595
+
2573
2596
Value *Cast = Builder.CreateCast (Opcode, Shuf, ShuffleDstTy);
2574
2597
2575
2598
// Intersect flags from the old casts.
2576
2599
if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
2577
2600
NewInst->copyIRFlags (C0);
2578
- NewInst->andIRFlags (C1);
2601
+ if (IsBinaryShuffle)
2602
+ NewInst->andIRFlags (C1);
2579
2603
}
2580
2604
2581
2605
Worklist.pushValue (Shuf);
0 commit comments