diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 61f9b83b5c697..f1b80a5b69d42 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1871,6 +1871,7 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
 }
 
 /// Try to convert any of:
+/// "shuffle (shuffle x, y), (shuffle y, x)"
 /// "shuffle (shuffle x, undef), (shuffle y, undef)"
 /// "shuffle (shuffle x, undef), y"
 /// "shuffle x, (shuffle y, undef)"
@@ -1883,68 +1884,93 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
     return false;
 
   ArrayRef<int> InnerMask0, InnerMask1;
-  Value *V0 = nullptr, *V1 = nullptr;
-  UndefValue *U0 = nullptr, *U1 = nullptr;
-  bool Match0 = match(
-      OuterV0, m_Shuffle(m_Value(V0), m_UndefValue(U0), m_Mask(InnerMask0)));
-  bool Match1 = match(
-      OuterV1, m_Shuffle(m_Value(V1), m_UndefValue(U1), m_Mask(InnerMask1)));
+  Value *X0, *X1, *Y0, *Y1;
+  bool Match0 =
+      match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
+  bool Match1 =
+      match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
   if (!Match0 && !Match1)
     return false;
 
-  V0 = Match0 ? V0 : OuterV0;
-  V1 = Match1 ? V1 : OuterV1;
+  X0 = Match0 ? X0 : OuterV0;
+  Y0 = Match0 ? Y0 : OuterV0;
+  X1 = Match1 ? X1 : OuterV1;
+  Y1 = Match1 ? Y1 : OuterV1;
 
   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
-  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
-  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
+  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
+  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
   if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
-      V0->getType() != V1->getType())
+      X0->getType() != X1->getType())
     return false;
 
   unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
   unsigned NumImmElts = ShuffleImmTy->getNumElements();
 
-  // Bail if either inner masks reference a RHS undef arg.
-  if ((Match0 && !isa<PoisonValue>(U0) &&
-       any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
-      (Match1 && !isa<PoisonValue>(U1) &&
-       any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
-    return false;
-
-  // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem,
+  // Attempt to merge shuffles, matching up to 2 source operands.
+  // Replace index to a poison arg with PoisonMaskElem.
+  // Bail if either inner mask references an undef arg.
   SmallVector<int> NewMask(OuterMask);
+  Value *NewX = nullptr, *NewY = nullptr;
   for (int &M : NewMask) {
+    Value *Src = nullptr;
     if (0 <= M && M < (int)NumImmElts) {
-      if (Match0)
-        M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
+      Src = OuterV0;
+      if (Match0) {
+        M = InnerMask0[M];
+        Src = M >= (int)NumSrcElts ? Y0 : X0;
+        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
+      }
     } else if (M >= (int)NumImmElts) {
+      Src = OuterV1;
+      M -= NumImmElts;
       if (Match1) {
-        if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
-          M = PoisonMaskElem;
-        else
-          M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
+        M = InnerMask1[M];
+        Src = M >= (int)NumSrcElts ? Y1 : X1;
+        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
      }
     }
+    if (Src && M != PoisonMaskElem) {
+      assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
+      if (isa<UndefValue>(Src)) {
+        // We've referenced an undef element - if it's poison, update the
+        // shuffle mask, else bail.
+        if (!isa<PoisonValue>(Src))
+          return false;
+        M = PoisonMaskElem;
+        continue;
+      }
+      if (!NewX || NewX == Src) {
+        NewX = Src;
+        continue;
+      }
+      if (!NewY || NewY == Src) {
+        M += NumSrcElts;
+        NewY = Src;
+        continue;
+      }
+      return false;
+    }
   }
 
+  if (!NewX)
+    return PoisonValue::get(ShuffleDstTy);
+  if (!NewY)
+    NewY = PoisonValue::get(ShuffleSrcTy);
+
   // Have we folded to an Identity shuffle?
   if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
-    replaceValue(I, *V0);
+    replaceValue(I, *NewX);
     return true;
   }
 
   // Try to merge the shuffles if the new shuffle is not costly.
   InstructionCost InnerCost0 = 0;
   if (Match0)
-    InnerCost0 = TTI.getShuffleCost(
-        TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask0,
-        CostKind, 0, nullptr, {V0, U0}, cast<Instruction>(OuterV0));
+    InnerCost0 = TTI.getInstructionCost(cast<Instruction>(OuterV0), CostKind);
 
   InstructionCost InnerCost1 = 0;
   if (Match1)
-    InnerCost1 = TTI.getShuffleCost(
-        TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask1,
-        CostKind, 0, nullptr, {V1, U1}, cast<Instruction>(OuterV1));
+    InnerCost1 = TTI.getInstructionCost(cast<Instruction>(OuterV1), CostKind);
 
   InstructionCost OuterCost = TTI.getShuffleCost(
       TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind,
@@ -1952,9 +1978,12 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
 
   InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
 
-  InstructionCost NewCost =
-      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy,
-                         NewMask, CostKind, 0, nullptr, {V0, V1});
+  bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
+  TargetTransformInfo::ShuffleKind SK =
+      IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
+              : TargetTransformInfo::SK_PermuteTwoSrc;
+  InstructionCost NewCost = TTI.getShuffleCost(
+      SK, ShuffleSrcTy, NewMask, CostKind, 0, nullptr, {NewX, NewY});
   if (!OuterV0->hasOneUse())
     NewCost += InnerCost0;
   if (!OuterV1->hasOneUse())
@@ -1966,13 +1995,7 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
   if (NewCost > OldCost)
     return false;
 
-  // Clear unused sources to poison.
-  if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
-    V0 = PoisonValue::get(ShuffleSrcTy);
-  if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
-    V1 = PoisonValue::get(ShuffleSrcTy);
-
-  Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
+  Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask);
   replaceValue(I, *Shuf);
   return true;
 }
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/hadd.ll b/llvm/test/Transforms/PhaseOrdering/X86/hadd.ll
index 7ca06c4a8791e..0a3599b7e7ff6 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/hadd.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/hadd.ll
@@ -391,21 +391,11 @@ define <4 x i32> @add_v4i32_01uu(<4 x i32> %a, <4 x i32> %b) {
 ;
 
 define <8 x i32> @add_v8i32_01234567(<8 x i32> %a, <8 x i32> %b) {
-; SSE-LABEL: @add_v8i32_01234567(
-; SSE-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <4 x i32>
-; SSE-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> [[B]], <4 x i32>
-; SSE-NEXT:    [[TMP3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> [[B]], <4 x i32>
-; SSE-NEXT:    [[TMP4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> [[B]], <4 x i32>
-; SSE-NEXT:    [[TMP5:%.*]] = add <4 x i32> [[TMP1]], [[TMP3]]
-; SSE-NEXT:    [[TMP6:%.*]] = add <4 x i32> [[TMP2]], [[TMP4]]
-; SSE-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i32> [[TMP5]], <4 x i32> [[TMP6]], <8 x i32>
-; SSE-NEXT:    ret <8 x i32> [[TMP7]]
-;
-; AVX-LABEL: @add_v8i32_01234567(
-; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <8 x i32>
-; AVX-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> [[B]], <8 x i32>
-; AVX-NEXT:    [[TMP3:%.*]] = add <8 x i32> [[TMP1]], [[TMP2]]
-; AVX-NEXT:    ret <8 x i32> [[TMP3]]
+; CHECK-LABEL: @add_v8i32_01234567(
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <8 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> [[B]], <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = add <8 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x i32> [[TMP3]]
 ;
   %a0 = extractelement <8 x i32> %a, i32 0
   %a1 = extractelement <8 x i32> %a, i32 1
@@ -786,21 +776,11 @@ define <4 x float> @add_v4f32_01uu(<4 x float> %a, <4 x float> %b) {
 ;
 
 define <8 x float> @add_v8f32_01234567(<8 x float> %a, <8 x float> %b) {
-; SSE-LABEL: @add_v8f32_01234567(
-; SSE-NEXT:    [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <4 x i32>
-; SSE-NEXT:    [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <4 x i32>
-; SSE-NEXT:    [[TMP3:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <4 x i32>
-; SSE-NEXT:    [[TMP4:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <4 x i32>
-; SSE-NEXT:    [[TMP5:%.*]] = fadd <4 x float> [[TMP1]], [[TMP3]]
-; SSE-NEXT:    [[TMP6:%.*]] = fadd <4 x float> [[TMP2]], [[TMP4]]
-; SSE-NEXT:    [[TMP7:%.*]] = shufflevector <4 x float> [[TMP5]], <4 x float> [[TMP6]], <8 x i32>
-; SSE-NEXT:    ret <8 x float> [[TMP7]]
-;
-; AVX-LABEL: @add_v8f32_01234567(
-; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32>
-; AVX-NEXT:    [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32>
-; AVX-NEXT:    [[TMP3:%.*]] = fadd <8 x float> [[TMP1]], [[TMP2]]
-; AVX-NEXT:    ret <8 x float> [[TMP3]]
+; CHECK-LABEL: @add_v8f32_01234567(
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <8 x float> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x float> [[TMP3]]
 ;
   %a0 = extractelement <8 x float> %a, i32 0
   %a1 = extractelement <8 x float> %a, i32 1
@@ -969,21 +949,11 @@ define <2 x double> @add_v2f64_0u(<2 x double> %a, <2 x double> %b) {
 ;
 
 define <4 x double> @add_v4f64_0123(<4 x double> %a, <4 x double> %b) {
-; SSE-LABEL: @add_v4f64_0123(
-; SSE-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[A:%.*]], <4 x double> [[B:%.*]], <2 x i32>
-; SSE-NEXT:    [[TMP2:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <2 x i32>
-; SSE-NEXT:    [[TMP3:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <2 x i32>
-; SSE-NEXT:    [[TMP4:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <2 x i32>
-; SSE-NEXT:    [[TMP5:%.*]] = fadd <2 x double> [[TMP1]], [[TMP3]]
-; SSE-NEXT:    [[TMP6:%.*]] = fadd <2 x double> [[TMP2]], [[TMP4]]
-; SSE-NEXT:    [[TMP7:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> [[TMP6]], <4 x i32>
-; SSE-NEXT:    ret <4 x double> [[TMP7]]
-;
-; AVX-LABEL: @add_v4f64_0123(
-; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[A:%.*]], <4 x double> [[B:%.*]], <4 x i32>
-; AVX-NEXT:    [[TMP2:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <4 x i32>
-; AVX-NEXT:    [[TMP3:%.*]] = fadd <4 x double> [[TMP1]], [[TMP2]]
-; AVX-NEXT:    ret <4 x double> [[TMP3]]
+; CHECK-LABEL: @add_v4f64_0123(
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[A:%.*]], <4 x double> [[B:%.*]], <4 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x double> [[A]], <4 x double> [[B]], <4 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <4 x double> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret <4 x double> [[TMP3]]
 ;
   %a0 = extractelement <4 x double> %a, i32 0
   %a1 = extractelement <4 x double> %a, i32 1
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll b/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll
index ebd2c5bd2574b..3a3ba74663b93 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/select-shuffle.ll
@@ -952,8 +952,7 @@ define <16 x i32> @testoutofbounds(<16 x i32> %x, <16 x i32> %y) {
 ; CHECK-NEXT:    [[A:%.*]] = add nsw <16 x i32> [[S1]], [[S2]]
 ; CHECK-NEXT:    [[B:%.*]] = sub nsw <16 x i32> [[S1]], [[S2]]
 ; CHECK-NEXT:    [[S3:%.*]] = shufflevector <16 x i32> [[A]], <16 x i32> [[B]], <16 x i32>
-; CHECK-NEXT:    [[S4:%.*]] = shufflevector <16 x i32> [[S3]], <16 x i32> poison, <16 x i32>
-; CHECK-NEXT:    [[ADD:%.*]] = add <16 x i32> [[S3]], [[S4]]
+; CHECK-NEXT:    [[ADD:%.*]] = add <16 x i32> [[S3]], [[B]]
 ; CHECK-NEXT:    ret <16 x i32> [[ADD]]
 ;
   %s1 = shufflevector <16 x i32> %x, <16 x i32> %y, <16 x i32>
@@ -973,8 +972,7 @@ define <64 x i32> @testlargerextrashuffle2(i32 %call.i, <16 x i32> %0) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP0]], i32 [[CALL_I]], i32 15
 ; CHECK-NEXT:    [[TMP3:%.*]] = sub <16 x i32> [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = add <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <16 x i32> [[TMP3]], <16 x i32> [[TMP4]], <16 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <16 x i32> [[TMP5]], <16 x i32> poison, <64 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <16 x i32> [[TMP3]], <16 x i32> [[TMP4]], <64 x i32>
 ; CHECK-NEXT:    ret <64 x i32> [[TMP6]]
 ;
 entry:
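
For reference, a minimal IR sketch of the new "shuffle (shuffle x, y), (shuffle y, x)" case: the value names (%x, %y, %lo, %hi, %out) and mask constants are illustrative and not taken from the patch or its tests. An outer shuffle whose operands are two-source inner shuffles of the same x/y pair is merged into a single two-source shuffle of x and y, provided the TTI cost model does not rate the merged mask as more expensive:

  ; before VectorCombine
  %lo  = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
  %hi  = shufflevector <4 x i32> %y, <4 x i32> %x, <4 x i32> <i32 2, i32 3, i32 6, i32 7>
  %out = shufflevector <4 x i32> %lo, <4 x i32> %hi, <4 x i32> <i32 0, i32 2, i32 4, i32 6>

  ; after foldShuffleOfShuffles: the inner shuffles are folded away and
  ; %out reads %x and %y directly with the merged mask
  %out = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 4, i32 6, i32 2>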