diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index b35f6d71f3945..b9caf8c0df9be 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1721,11 +1721,11 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { Value *V0, *V1; UndefValue *U0, *U1; ArrayRef OuterMask, InnerMask0, InnerMask1; - if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_UndefValue(U0), - m_Mask(InnerMask0))), - m_OneUse(m_Shuffle(m_Value(V1), m_UndefValue(U1), - m_Mask(InnerMask1))), - m_Mask(OuterMask)))) + if (!match(&I, + m_Shuffle( + m_Shuffle(m_Value(V0), m_UndefValue(U0), m_Mask(InnerMask0)), + m_Shuffle(m_Value(V1), m_UndefValue(U1), m_Mask(InnerMask1)), + m_Mask(OuterMask)))) return false; auto *ShufI0 = dyn_cast(I.getOperand(0)); @@ -1769,17 +1769,24 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { // Try to merge the shuffles if the new shuffle is not costly. TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - InstructionCost OldCost = + InstructionCost InnerCost0 = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, - InnerMask0, CostKind, 0, nullptr, {V0, U0}, ShufI0) + + InnerMask0, CostKind, 0, nullptr, {V0, U0}, ShufI0); + InstructionCost InnerCost1 = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, - InnerMask1, CostKind, 0, nullptr, {V1, U1}, ShufI1) + + InnerMask1, CostKind, 0, nullptr, {V1, U1}, ShufI1); + InstructionCost OuterCost = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind, 0, nullptr, {ShufI0, ShufI1}, &I); + InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost; InstructionCost NewCost = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind, 0, nullptr, {V0, V1}); + if (!ShufI0->hasOneUse()) + NewCost += InnerCost0; + if (!ShufI1->hasOneUse()) + NewCost += InnerCost1; LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll index 7472e1bc52bb8..d4446e27742f8 100644 --- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll @@ -1080,11 +1080,9 @@ define <16 x i64> @operandbundles(<4 x i64> %a, <4 x i64> %b, <4 x i64> %c) { define <8 x i8> @operandbundles_first(<8 x i8> %a) { ; CHECK-LABEL: @operandbundles_first( -; CHECK-NEXT: [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> -; CHECK-NEXT: [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> +; CHECK-NEXT: [[AT:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> ; CHECK-NEXT: [[ABT:%.*]] = call <4 x i8> @llvm.abs.v4i8(<4 x i8> [[AT]], i1 false) [ "jl_roots"(ptr addrspace(10) null, ptr addrspace(10) null) ] -; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i8> [[AT]], <4 x i8> [[AB]], <8 x i32> -; CHECK-NEXT: [[R:%.*]] = call <8 x i8> @llvm.abs.v8i8(<8 x i8> [[TMP1]], i1 false) +; CHECK-NEXT: [[R:%.*]] = call <8 x i8> @llvm.abs.v8i8(<8 x i8> [[A]], i1 false) ; CHECK-NEXT: ret <8 x i8> [[R]] ; %ab = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32> @@ -1098,10 +1096,8 @@ define <8 x i8> @operandbundles_first(<8 x i8> %a) { define <8 x i8> @operandbundles_second(<8 x i8> %a) { ; CHECK-LABEL: @operandbundles_second( ; CHECK-NEXT: [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> -; CHECK-NEXT: [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> ; CHECK-NEXT: [[ABB:%.*]] = call <4 x i8> @llvm.abs.v4i8(<4 x i8> [[AB]], i1 false) [ "jl_roots"(ptr addrspace(10) null, ptr addrspace(10) null) ] -; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i8> [[AT]], <4 x i8> [[AB]], <8 x i32> -; CHECK-NEXT: [[R:%.*]] = call <8 x i8> @llvm.abs.v8i8(<8 x i8> [[TMP1]], i1 false) +; CHECK-NEXT: [[R:%.*]] = call <8 x i8> @llvm.abs.v8i8(<8 x i8> [[A]], i1 false) ; CHECK-NEXT: ret <8 x i8> [[R]] ; %ab = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32>