diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b5d930ed4f7c3..04c7d84b4d95d 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -10096,7 +10096,10 @@ static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask, if (Size != (int)ExpectedMask.size()) return false; assert(llvm::all_of(ExpectedMask, - [Size](int M) { return isInRange(M, 0, 2 * Size); }) && + [Size](int M) { + return M == SM_SentinelZero || + isInRange(M, 0, 2 * Size); + }) && "Illegal target shuffle mask"); // Check for out-of-range target shuffle mask indices. @@ -10119,6 +10122,9 @@ static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask, int ExpectedIdx = ExpectedMask[i]; if (MaskIdx == SM_SentinelUndef || MaskIdx == ExpectedIdx) continue; + // If we failed to match an expected SM_SentinelZero then early out. + if (ExpectedIdx < 0) + return false; if (MaskIdx == SM_SentinelZero) { // If we need this expected index to be a zero element, then update the // relevant zero mask and perform the known bits at the end to minimize @@ -39594,18 +39600,46 @@ static bool matchBinaryPermuteShuffle( ((MaskVT.is128BitVector() && Subtarget.hasVLX()) || (MaskVT.is256BitVector() && Subtarget.hasVLX()) || (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) { + MVT AlignVT = MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits), + MaskVT.getSizeInBits() / EltSizeInBits); if (!isAnyZero(Mask)) { int Rotation = matchShuffleAsElementRotate(V1, V2, Mask); if (0 < Rotation) { Shuffle = X86ISD::VALIGN; - if (EltSizeInBits == 64) - ShuffleVT = MVT::getVectorVT(MVT::i64, MaskVT.getSizeInBits() / 64); - else - ShuffleVT = MVT::getVectorVT(MVT::i32, MaskVT.getSizeInBits() / 32); + ShuffleVT = AlignVT; PermuteImm = Rotation; return true; } } + // See if we can use VALIGN as a cross-lane version of VSHLDQ/VSRLDQ. 
+ unsigned ZeroLo = Zeroable.countr_one(); + unsigned ZeroHi = Zeroable.countl_one(); + assert((ZeroLo + ZeroHi) < NumMaskElts && "Zeroable shuffle detected"); + if (ZeroLo) { + SmallVector<int, 16> ShiftMask(NumMaskElts, SM_SentinelZero); + std::iota(ShiftMask.begin() + ZeroLo, ShiftMask.end(), 0); + if (isTargetShuffleEquivalent(MaskVT, Mask, ShiftMask, DAG, V1)) { + V2 = getZeroVector(AlignVT, Subtarget, DAG, DL); + Shuffle = X86ISD::VALIGN; + ShuffleVT = AlignVT; + PermuteImm = NumMaskElts - ZeroLo; + return true; + } + } + if (ZeroHi) { + SmallVector<int, 16> ShiftMask(NumMaskElts, SM_SentinelZero); + std::iota(ShiftMask.begin(), ShiftMask.begin() + NumMaskElts - ZeroHi, + ZeroHi); + if (isTargetShuffleEquivalent(MaskVT, Mask, ShiftMask, DAG, V1)) { + V2 = V1; + V1 = getZeroVector(AlignVT, Subtarget, DAG, DL); + Shuffle = X86ISD::VALIGN; + ShuffleVT = AlignVT; + PermuteImm = ZeroHi; + return true; + } + } } // Attempt to match against PALIGNR byte rotate. diff --git a/llvm/test/CodeGen/X86/vector-shuffle-combining-avx512f.ll b/llvm/test/CodeGen/X86/vector-shuffle-combining-avx512f.ll index b3b90b5f51501..68967c2ce6536 100644 --- a/llvm/test/CodeGen/X86/vector-shuffle-combining-avx512f.ll +++ b/llvm/test/CodeGen/X86/vector-shuffle-combining-avx512f.ll @@ -812,10 +812,8 @@ define <8 x i64> @combine_vpermt2var_8i64_as_valignq(<8 x i64> %x0, <8 x i64> %x define <8 x i64> @combine_vpermt2var_8i64_as_valignq_zero(<8 x i64> %x0) { ; CHECK-LABEL: combine_vpermt2var_8i64_as_valignq_zero: ; CHECK: # %bb.0: -; CHECK-NEXT: vpmovsxbq {{.*#+}} zmm2 = [15,0,1,2,3,4,5,6] ; CHECK-NEXT: vpxor %xmm1, %xmm1, %xmm1 -; CHECK-NEXT: vpermt2q %zmm0, %zmm2, %zmm1 -; CHECK-NEXT: vmovdqa64 %zmm1, %zmm0 +; CHECK-NEXT: valignq {{.*#+}} zmm0 = zmm0[7],zmm1[0,1,2,3,4,5,6] ; CHECK-NEXT: ret{{[l|q]}} %res0 = call <8 x i64> @llvm.x86.avx512.maskz.vpermt2var.q.512(<8 x i64> <i64 15, i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6>, <8 x i64> zeroinitializer, <8 x i64> %x0, i8 -1) ret <8 x i64> %res0 @@ -825,8 +823,7 @@ define <8 x i64> 
@combine_vpermt2var_8i64_as_zero_valignq(<8 x i64> %x0) { ; CHECK-LABEL: combine_vpermt2var_8i64_as_zero_valignq: ; CHECK: # %bb.0: ; CHECK-NEXT: vpxor %xmm1, %xmm1, %xmm1 -; CHECK-NEXT: vpmovsxbq {{.*#+}} zmm2 = [15,0,1,2,3,4,5,6] -; CHECK-NEXT: vpermt2q %zmm1, %zmm2, %zmm0 +; CHECK-NEXT: valignq {{.*#+}} zmm0 = zmm1[7],zmm0[0,1,2,3,4,5,6] ; CHECK-NEXT: ret{{[l|q]}} %res0 = call <8 x i64> @llvm.x86.avx512.maskz.vpermt2var.q.512(<8 x i64> <i64 15, i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6>, <8 x i64> %x0, <8 x i64> zeroinitializer, i8 -1) ret <8 x i64> %res0