@@ -2487,21 +2487,31 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
24872487 if (!match (&I, m_Shuffle (m_Value (V0), m_Value (V1), m_Mask (OldMask))))
24882488 return false ;
24892489
2490+ // Check whether this is a binary shuffle.
2491+ bool IsBinaryShuffle = !isa<UndefValue>(V1);
2492+
24902493 auto *C0 = dyn_cast<CastInst>(V0);
24912494 auto *C1 = dyn_cast<CastInst>(V1);
2492- if (!C0 || !C1)
2495+ if (!C0 || (IsBinaryShuffle && !C1) )
24932496 return false ;
24942497
24952498 Instruction::CastOps Opcode = C0->getOpcode ();
2496- if (C0->getSrcTy () != C1->getSrcTy ())
2499+
2500+ // If this is allowed, foldShuffleOfCastops can get stuck in a loop
2501+ // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2502+ if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
24972503 return false ;
24982504
2499- // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2500- if (Opcode != C1->getOpcode ()) {
2501- if (match (C0, m_SExtLike (m_Value ())) && match (C1, m_SExtLike (m_Value ())))
2502- Opcode = Instruction::SExt;
2503- else
2505+ if (IsBinaryShuffle) {
2506+ if (C0->getSrcTy () != C1->getSrcTy ())
25042507 return false ;
2508+ // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2509+ if (Opcode != C1->getOpcode ()) {
2510+ if (match (C0, m_SExtLike (m_Value ())) && match (C1, m_SExtLike (m_Value ())))
2511+ Opcode = Instruction::SExt;
2512+ else
2513+ return false ;
2514+ }
25052515 }
25062516
25072517 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
@@ -2544,38 +2554,52 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
25442554 InstructionCost CostC0 =
25452555 TTI.getCastInstrCost (C0->getOpcode (), CastDstTy, CastSrcTy,
25462556 TTI::CastContextHint::None, CostKind);
2547- InstructionCost CostC1 =
2548- TTI.getCastInstrCost (C1->getOpcode (), CastDstTy, CastSrcTy,
2549- TTI::CastContextHint::None, CostKind);
2550- InstructionCost OldCost = CostC0 + CostC1;
2551- OldCost +=
2552- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2553- CastDstTy, OldMask, CostKind, 0 , nullptr , {}, &I);
25542557
2555- InstructionCost NewCost =
2556- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
2557- CastSrcTy, NewMask, CostKind);
2558+ TargetTransformInfo::ShuffleKind ShuffleKind;
2559+ if (IsBinaryShuffle)
2560+ ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2561+ else
2562+ ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2563+
2564+ InstructionCost OldCost = CostC0;
2565+ OldCost += TTI.getShuffleCost (ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2566+ CostKind, 0 , nullptr , {}, &I);
2567+
2568+ InstructionCost NewCost = TTI.getShuffleCost (ShuffleKind, NewShuffleDstTy,
2569+ CastSrcTy, NewMask, CostKind);
25582570 NewCost += TTI.getCastInstrCost (Opcode, ShuffleDstTy, NewShuffleDstTy,
25592571 TTI::CastContextHint::None, CostKind);
25602572 if (!C0->hasOneUse ())
25612573 NewCost += CostC0;
2562- if (!C1->hasOneUse ())
2563- NewCost += CostC1;
2574+ if (IsBinaryShuffle) {
2575+ InstructionCost CostC1 =
2576+ TTI.getCastInstrCost (C1->getOpcode (), CastDstTy, CastSrcTy,
2577+ TTI::CastContextHint::None, CostKind);
2578+ OldCost += CostC1;
2579+ if (!C1->hasOneUse ())
2580+ NewCost += CostC1;
2581+ }
25642582
25652583 LLVM_DEBUG (dbgs () << " Found a shuffle feeding two casts: " << I
25662584 << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
25672585 << " \n " );
25682586 if (NewCost > OldCost)
25692587 return false ;
25702588
2571- Value *Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ),
2572- C1->getOperand (0 ), NewMask);
2589+ Value *Shuf;
2590+ if (IsBinaryShuffle)
2591+ Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ), C1->getOperand (0 ),
2592+ NewMask);
2593+ else
2594+ Shuf = Builder.CreateShuffleVector (C0->getOperand (0 ), NewMask);
2595+
25732596 Value *Cast = Builder.CreateCast (Opcode, Shuf, ShuffleDstTy);
25742597
25752598 // Intersect flags from the old casts.
25762599 if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
25772600 NewInst->copyIRFlags (C0);
2578- NewInst->andIRFlags (C1);
2601+ if (IsBinaryShuffle)
2602+ NewInst->andIRFlags (C1);
25792603 }
25802604
25812605 Worklist.pushValue (Shuf);
0 commit comments