@@ -684,10 +684,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
684684// / destination type followed by shuffle. This can enable further transforms by
685685// / moving bitcasts or shuffles together.
686686bool VectorCombine::foldBitcastShuffle (Instruction &I) {
687- Value *V ;
687+ Value *V0 ;
688688 ArrayRef<int > Mask;
689- if (!match (&I, m_BitCast (
690- m_OneUse ( m_Shuffle (m_Value (V ), m_Undef (), m_Mask (Mask))))))
689+ if (!match (&I, m_BitCast (m_OneUse (
690+ m_Shuffle (m_Value (V0 ), m_Undef (), m_Mask (Mask))))))
691691 return false ;
692692
693693 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
@@ -696,7 +696,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
696696 // 2) Disallow non-vector casts.
697697 // TODO: We could allow any shuffle.
698698 auto *DestTy = dyn_cast<FixedVectorType>(I.getType ());
699- auto *SrcTy = dyn_cast<FixedVectorType>(V ->getType ());
699+ auto *SrcTy = dyn_cast<FixedVectorType>(V0 ->getType ());
700700 if (!DestTy || !SrcTy)
701701 return false ;
702702
@@ -724,20 +724,31 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
724724 // Bitcast the shuffle src - keep its original width but using the destination
725725 // scalar type.
726726 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits () / DestEltSize;
727- auto *ShuffleTy = FixedVectorType::get (DestTy->getScalarType (), NumSrcElts);
728-
729- // The new shuffle must not cost more than the old shuffle. The bitcast is
730- // moved ahead of the shuffle, so assume that it has the same cost as before.
731- InstructionCost DestCost = TTI.getShuffleCost (
732- TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
727+ auto *NewShuffleTy =
728+ FixedVectorType::get (DestTy->getScalarType (), NumSrcElts);
729+ auto *OldShuffleTy =
730+ FixedVectorType::get (SrcTy->getScalarType (), Mask.size ());
731+
732+ // The new shuffle must not cost more than the old shuffle.
733+ TargetTransformInfo::TargetCostKind CK =
734+ TargetTransformInfo::TCK_RecipThroughput;
735+ TargetTransformInfo::ShuffleKind SK =
736+ TargetTransformInfo::SK_PermuteSingleSrc;
737+
738+ InstructionCost DestCost =
739+ TTI.getShuffleCost (SK, NewShuffleTy, NewMask, CK) +
740+ TTI.getCastInstrCost (Instruction::BitCast, NewShuffleTy, SrcTy,
741+ TargetTransformInfo::CastContextHint::None, CK);
733742 InstructionCost SrcCost =
734- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
743+ TTI.getShuffleCost (SK, SrcTy, Mask, CK) +
744+ TTI.getCastInstrCost (Instruction::BitCast, DestTy, OldShuffleTy,
745+ TargetTransformInfo::CastContextHint::None, CK);
735746 if (DestCost > SrcCost || !DestCost.isValid ())
736747 return false ;
737748
738- // bitcast (shuf V , MaskC) --> shuf (bitcast V ), MaskC'
749+ // bitcast (shuf V0 , MaskC) --> shuf (bitcast V0 ), MaskC'
739750 ++NumShufOfBitcast;
740- Value *CastV = Builder.CreateBitCast (V, ShuffleTy );
751+ Value *CastV = Builder.CreateBitCast (V0, NewShuffleTy );
741752 Value *Shuf = Builder.CreateShuffleVector (CastV, NewMask);
742753 replaceValue (I, *Shuf);
743754 return true ;
0 commit comments