
Commit 9191944

[AutoBump] Merge with fixes of 49c5ceb (Sep 16)
2 parents: 45c9757 + 49c5ceb

File tree: 9 files changed (+568, -113 lines)

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 142 additions & 91 deletions
@@ -29830,6 +29830,144 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     }
   }
 
+  // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
+  // using vXi16 vector operations.
+  if (ConstantAmt &&
+      (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
+       (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
+      !Subtarget.hasXOP()) {
+    int NumElts = VT.getVectorNumElements();
+    MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
+    // We can do this extra fast if each pair of i8 elements is shifted by the
+    // same amount by doing this SWAR style: use a shift to move the valid bits
+    // to the right position, mask out any bits which crossed from one element
+    // to the other.
+    APInt UndefElts;
+    SmallVector<APInt, 64> AmtBits;
+    // This optimized lowering is only valid if the elements in a pair can
+    // be treated identically.
+    bool SameShifts = true;
+    SmallVector<APInt, 32> AmtBits16(NumElts / 2);
+    APInt UndefElts16 = APInt::getZero(AmtBits16.size());
+    if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
+                                      AmtBits, /*AllowWholeUndefs=*/true,
+                                      /*AllowPartialUndefs=*/false)) {
+      // Collect information to construct the BUILD_VECTOR for the i16 version
+      // of the shift. Conceptually, this is equivalent to:
+      // 1. Making sure the shift amounts are the same for both the low i8 and
+      //    high i8 corresponding to the i16 lane.
+      // 2. Extending that shift amount to i16 for a build vector operation.
+      //
+      // We want to handle undef shift amounts which requires a little more
+      // logic (e.g. if one is undef and the other is not, grab the other shift
+      // amount).
+      for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
+        unsigned DstI = SrcI / 2;
+        // Both elements are undef? Make a note and keep going.
+        if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
+          AmtBits16[DstI] = APInt::getZero(16);
+          UndefElts16.setBit(DstI);
+          continue;
+        }
+        // Even element is undef? We will shift it by the same shift amount as
+        // the odd element.
+        if (UndefElts[SrcI]) {
+          AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
+          continue;
+        }
+        // Odd element is undef? We will shift it by the same shift amount as
+        // the even element.
+        if (UndefElts[SrcI + 1]) {
+          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+          continue;
+        }
+        // Both elements are equal.
+        if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
+          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+          continue;
+        }
+        // One of the provisional i16 elements will not have the same shift
+        // amount. Let's bail.
+        SameShifts = false;
+        break;
+      }
+    }
+    // We are only dealing with identical pairs.
+    if (SameShifts) {
+      // Cast the operand to vXi16.
+      SDValue R16 = DAG.getBitcast(VT16, R);
+      // Create our new vector of shift amounts.
+      SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
+      // Perform the actual shift.
+      unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
+      SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
+      // Now we need to construct a mask which will "drop" bits that get
+      // shifted past the LSB/MSB. For a logical shift left, it will look
+      // like:
+      //   MaskLowBits = (0xff << Amt16) & 0xff;
+      //   MaskHighBits = MaskLowBits << 8;
+      //   Mask = MaskLowBits | MaskHighBits;
+      //
+      // This masking ensures that bits cannot migrate from one i8 to
+      // another. The construction of this mask will be constant folded.
+      // The mask for a logical right shift is nearly identical, the only
+      // difference is that 0xff is shifted right instead of left.
+      SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
+      SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
+      // The mask for the low bits is most simply expressed as an 8-bit
+      // field of all ones which is shifted in the exact same way the data
+      // is shifted but masked with 0xff.
+      SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
+      MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
+      SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
+      SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
+      // The mask for the high bits is the same as the mask for the low bits
+      // but shifted up by 8.
+      SDValue MaskHighBits =
+          DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
+      SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
+      // Finally, we mask the shifted vector with the SWAR mask.
+      SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
+      Masked = DAG.getBitcast(VT, Masked);
+      if (Opc != ISD::SRA) {
+        // Logical shifts are complete at this point.
+        return Masked;
+      }
+      // At this point, we have done a *logical* shift right. We now need to
+      // sign extend the result so that we get behavior equivalent to an
+      // arithmetic shift right. Post-shifting by Amt16, our i8 elements are
+      // `8-Amt16` bits wide.
+      //
+      // To convert our `8-Amt16` bit unsigned numbers to 8-bit signed numbers,
+      // we need to replicate the bit at position `7-Amt16` into the MSBs of
+      // each i8.
+      // We can use the following trick to accomplish this:
+      //   SignBitMask = 1 << (7-Amt16)
+      //   (Masked ^ SignBitMask) - SignBitMask
+      //
+      // When the sign bit is already clear, this will compute:
+      //   Masked + SignBitMask - SignBitMask
+      //
+      // This is equal to Masked which is what we want: the sign bit was clear
+      // so sign extending should be a no-op.
+      //
+      // When the sign bit is set, this will compute:
+      //   Masked - SignBitMask - SignBitMask
+      //
+      // This is equal to Masked - 2*SignBitMask which will correctly sign
+      // extend our result.
+      SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
+      SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
+      // This does not induce recursion, all operands are constants.
+      SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
+      SDValue FlippedSignBit =
+          DAG.getNode(ISD::XOR, dl, VT, Masked, SignBitMask);
+      SDValue Subtraction =
+          DAG.getNode(ISD::SUB, dl, VT, FlippedSignBit, SignBitMask);
+      return Subtraction;
+    }
+  }
+
   // If possible, lower this packed shift into a vector multiply instead of
   // expanding it into a sequence of scalar shifts.
   // For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
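As a sanity check on the SWAR lowering added above, here is a standalone C++ sketch (not LLVM code; the helper names swarShiftPair and referenceShift are invented for illustration). It performs the same steps on a single i16 lane holding two i8 elements that share one constant shift amount: one 16-bit shift, a mask that drops the bits which crossed the i8 boundary, and, for SRA, the `(x ^ SignBit) - SignBit` fixup applied per byte, mirroring the fact that the lowering builds the XOR/SUB at the vXi8 element width so the borrow never crosses a byte.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

enum class ShiftOp { Shl, Srl, Sra };

// Reference semantics: shift each i8 of the pair independently. (Relies on
// '>>' of a negative int8_t being an arithmetic shift, as on all mainstream
// compilers.)
static uint16_t referenceShift(uint16_t Pair, unsigned Amt, ShiftOp Op) {
  auto One = [&](uint8_t V) -> uint8_t {
    switch (Op) {
    case ShiftOp::Shl: return uint8_t(V << Amt);
    case ShiftOp::Srl: return uint8_t(V >> Amt);
    case ShiftOp::Sra: return uint8_t(int8_t(V) >> Amt);
    }
    return 0;
  };
  return uint16_t(One(uint8_t(Pair))) | uint16_t(One(uint8_t(Pair >> 8))) << 8;
}

// SWAR emulation: one 16-bit shift, then a mask that drops any bits which
// crossed from one i8 element into the other; SRA reuses the SRL path and
// then sign-extends with (x ^ SignBit) - SignBit, done per byte.
static uint16_t swarShiftPair(uint16_t Pair, unsigned Amt, ShiftOp Op) {
  const bool Left = Op == ShiftOp::Shl;
  uint16_t Shifted = Left ? uint16_t(Pair << Amt) : uint16_t(Pair >> Amt);
  uint16_t MaskLowBits = (Left ? (0xffu << Amt) : (0xffu >> Amt)) & 0xff;
  uint16_t Mask = uint16_t(MaskLowBits | (MaskLowBits << 8));
  uint16_t Masked = Shifted & Mask;
  if (Op != ShiftOp::Sra)
    return Masked; // Logical shifts are complete at this point.
  uint8_t SignBit = uint8_t(0x80u >> Amt); // bit 7-Amt now holds the sign
  uint8_t Lo = uint8_t((uint8_t(Masked) ^ SignBit) - SignBit);
  uint8_t Hi = uint8_t((uint8_t(Masked >> 8) ^ SignBit) - SignBit);
  return uint16_t(Lo) | uint16_t(Hi) << 8;
}

int main() {
  for (unsigned Amt = 0; Amt < 8; ++Amt)
    for (unsigned Pair = 0; Pair <= 0xffff; ++Pair)
      for (ShiftOp Op : {ShiftOp::Shl, ShiftOp::Srl, ShiftOp::Sra})
        assert(swarShiftPair(uint16_t(Pair), Amt, Op) ==
               referenceShift(uint16_t(Pair), Amt, Op));
  std::puts("SWAR pair shift matches the per-i8 reference for all inputs");
  return 0;
}
```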
@@ -29950,105 +30088,18 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
                        DAG.getNode(Opc, dl, ExtVT, R, Amt));
   }
 
-  // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors by using
-  // vXi16 vector operations.
+  // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
+  // extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
   if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
       (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
        (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
       !Subtarget.hasXOP()) {
     int NumElts = VT.getVectorNumElements();
     MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
-    // We can do this extra fast if each pair of i8 elements is shifted by the
-    // same amount by doing this SWAR style: use a shift to move the valid bits
-    // to the right position, mask out any bits which crossed from one element
-    // to the other.
-    if (Opc == ISD::SRL || Opc == ISD::SHL) {
-      APInt UndefElts;
-      SmallVector<APInt, 64> AmtBits;
-      if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
-                                        AmtBits, /*AllowWholeUndefs=*/true,
-                                        /*AllowPartialUndefs=*/false)) {
-        // This optimized lowering is only valid if the elements in a pair can
-        // be treated identically.
-        bool SameShifts = true;
-        SmallVector<APInt, 32> AmtBits16(NumElts / 2);
-        APInt UndefElts16 = APInt::getZero(AmtBits16.size());
-        for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
-          unsigned DstI = SrcI / 2;
-          // Both elements are undef? Make a note and keep going.
-          if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
-            AmtBits16[DstI] = APInt::getZero(16);
-            UndefElts16.setBit(DstI);
-            continue;
-          }
-          // Even element is undef? We will shift it by the same shift amount as
-          // the odd element.
-          if (UndefElts[SrcI]) {
-            AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
-            continue;
-          }
-          // Odd element is undef? We will shift it by the same shift amount as
-          // the even element.
-          if (UndefElts[SrcI + 1]) {
-            AmtBits16[DstI] = AmtBits[SrcI].zext(16);
-            continue;
-          }
-          // Both elements are equal.
-          if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
-            AmtBits16[DstI] = AmtBits[SrcI].zext(16);
-            continue;
-          }
-          // One of the provisional i16 elements will not have the same shift
-          // amount. Let's bail.
-          SameShifts = false;
-          break;
-        }
-
-        // We are only dealing with identical pairs and the operation is a
-        // logical shift.
-        if (SameShifts) {
-          // Cast the operand to vXi16.
-          SDValue R16 = DAG.getBitcast(VT16, R);
-          // Create our new vector of shift amounts.
-          SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
-          // Perform the actual shift.
-          SDValue ShiftedR = DAG.getNode(Opc, dl, VT16, R16, Amt16);
-          // Now we need to construct a mask which will "drop" bits that get
-          // shifted past the LSB/MSB. For a logical shift left, it will look
-          // like:
-          //   MaskLowBits = (0xff << Amt16) & 0xff;
-          //   MaskHighBits = MaskLowBits << 8;
-          //   Mask = MaskLowBits | MaskHighBits;
-          //
-          // This masking ensures that bits cannot migrate from one i8 to
-          // another. The construction of this mask will be constant folded.
-          // The mask for a logical right shift is nearly identical, the only
-          // difference is that 0xff is shifted right instead of left.
-          SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
-          SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
-          // The mask for the low bits is most simply expressed as an 8-bit
-          // field of all ones which is shifted in the exact same way the data
-          // is shifted but masked with 0xff.
-          SDValue MaskLowBits = DAG.getNode(Opc, dl, VT16, Splat255, Amt16);
-          MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
-          SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
-          SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
-          // Thie mask for the high bits is the same as the mask for the low
-          // bits but shifted up by 8.
-          SDValue MaskHighBits =
-              DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
-          SDValue Mask =
-              DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
-          // Finally, we mask the shifted vector with the SWAR mask.
-          SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
-          return DAG.getBitcast(VT, Masked);
-        }
-      }
-    }
     SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);
 
-    // Extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI (it
-    // doesn't matter if the type isn't legal).
+    // Extend constant shift amount to vXi16 (it doesn't matter if the type
+    // isn't legal).
     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
     Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT);
     Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt);
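The path kept by this hunk is described in the updated comment as a MUL scale performed effectively as a MUL_LOHI, with the amount rewritten to `8 - Amt`. The underlying arithmetic identity is that a constant right shift of an i8 equals the high byte of a 16-bit product with the scale `1 << (8 - Amt)`: unsigned multiply-high gives SRL and signed multiply-high gives SRA. The exact DAG construction is not shown in this hunk, so the standalone C++ check below only verifies that identity; it is not a transcription of the surrounding code.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned Amt = 0; Amt < 8; ++Amt) {
    const int Scale = 1 << (8 - Amt); // per-element multiplier derived from 8 - Amt
    for (unsigned V = 0; V <= 0xff; ++V) {
      // SRL: zero-extend to 16 bits, multiply, keep the high byte.
      assert(uint8_t((V * unsigned(Scale)) >> 8) == uint8_t(V >> Amt));
      // SRA: sign-extend, multiply, keep the high byte (floor division).
      int S = int8_t(V);
      assert(int8_t((S * Scale) >> 8) == int8_t(S >> Amt));
    }
  }
  std::puts("high byte of x * (1 << (8 - Amt)) matches srl/sra for all i8 x");
  return 0;
}
```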

llvm/test/CodeGen/X86/vector-shift-ashr-128.ll

Lines changed: 93 additions & 0 deletions
@@ -1586,6 +1586,99 @@ define <16 x i8> @constant_shift_v16i8(<16 x i8> %a) nounwind {
   ret <16 x i8> %shift
 }
 
+define <16 x i8> @constant_shift_v16i8_pairs(<16 x i8> %a) nounwind {
+; SSE2-LABEL: constant_shift_v16i8_pairs:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movdqa {{.*#+}} xmm1 = [65535,65535,65535,65535,65535,0,65535,65535]
+; SSE2-NEXT: pandn %xmm0, %xmm1
+; SSE2-NEXT: pmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT: por %xmm1, %xmm0
+; SSE2-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT: movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; SSE2-NEXT: pxor %xmm1, %xmm0
+; SSE2-NEXT: psubb %xmm1, %xmm0
+; SSE2-NEXT: retq
+;
+; SSE41-LABEL: constant_shift_v16i8_pairs:
+; SSE41: # %bb.0:
+; SSE41-NEXT: movdqa {{.*#+}} xmm1 = [32768,4096,512,8192,16384,u,2048,1024]
+; SSE41-NEXT: pmulhuw %xmm0, %xmm1
+; SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; SSE41-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE41-NEXT: movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; SSE41-NEXT: pxor %xmm1, %xmm0
+; SSE41-NEXT: psubb %xmm1, %xmm0
+; SSE41-NEXT: retq
+;
+; AVX-LABEL: constant_shift_v16i8_pairs:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
+; AVX-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; AVX-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX-NEXT: vpxor %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vpsubb %xmm1, %xmm0, %xmm0
+; AVX-NEXT: retq
+;
+; XOP-LABEL: constant_shift_v16i8_pairs:
+; XOP: # %bb.0:
+; XOP-NEXT: vpshab {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; XOP-NEXT: retq
+;
+; AVX512DQ-LABEL: constant_shift_v16i8_pairs:
+; AVX512DQ: # %bb.0:
+; AVX512DQ-NEXT: vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
+; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; AVX512DQ-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512DQ-NEXT: vpxor %xmm1, %xmm0, %xmm0
+; AVX512DQ-NEXT: vpsubb %xmm1, %xmm0, %xmm0
+; AVX512DQ-NEXT: retq
+;
+; AVX512BW-LABEL: constant_shift_v16i8_pairs:
+; AVX512BW: # %bb.0:
+; AVX512BW-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
+; AVX512BW-NEXT: vpmovsxbw {{.*#+}} xmm1 = [1,4,7,3,2,0,5,6]
+; AVX512BW-NEXT: vpsrlvw %zmm1, %zmm0, %zmm0
+; AVX512BW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX512BW-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512BW-NEXT: vpxor %xmm1, %xmm0, %xmm0
+; AVX512BW-NEXT: vpsubb %xmm1, %xmm0, %xmm0
+; AVX512BW-NEXT: vzeroupper
+; AVX512BW-NEXT: retq
+;
+; AVX512DQVL-LABEL: constant_shift_v16i8_pairs:
+; AVX512DQVL: # %bb.0:
+; AVX512DQVL-NEXT: vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
+; AVX512DQVL-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; AVX512DQVL-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512DQVL-NEXT: vpternlogq $108, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX512DQVL-NEXT: vpsubb %xmm1, %xmm0, %xmm0
+; AVX512DQVL-NEXT: retq
+;
+; AVX512BWVL-LABEL: constant_shift_v16i8_pairs:
+; AVX512BWVL: # %bb.0:
+; AVX512BWVL-NEXT: vpsrlvw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX512BWVL-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512BWVL-NEXT: vpternlogq $108, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX512BWVL-NEXT: vpsubb %xmm1, %xmm0, %xmm0
+; AVX512BWVL-NEXT: retq
+;
+; X86-SSE-LABEL: constant_shift_v16i8_pairs:
+; X86-SSE: # %bb.0:
+; X86-SSE-NEXT: movdqa {{.*#+}} xmm1 = [65535,65535,65535,65535,65535,0,65535,65535]
+; X86-SSE-NEXT: pandn %xmm0, %xmm1
+; X86-SSE-NEXT: pmulhuw {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
+; X86-SSE-NEXT: por %xmm1, %xmm0
+; X86-SSE-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
+; X86-SSE-NEXT: movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; X86-SSE-NEXT: pxor %xmm1, %xmm0
+; X86-SSE-NEXT: psubb %xmm1, %xmm0
+; X86-SSE-NEXT: retl
+  %shift = ashr <16 x i8> %a, <i8 1, i8 1, i8 4, i8 4, i8 7, i8 7, i8 3, i8 3, i8 2, i8 2, i8 0, i8 0, i8 5, i8 5, i8 6, i8 6>
+  ret <16 x i8> %shift
+}
+
 ;
 ; Uniform Constant Shifts
 ;
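The constants baked into the expected SSE/AVX sequences for the new constant_shift_v16i8_pairs test follow directly from the per-pair shift amounts <1,4,7,3,2,0,5,6>: the pmulhuw multiplier for a 16-bit logical right shift by Amt is 1 << (16 - Amt) (no such multiplier exists for Amt == 0, which is why that lane is blended back from the input), and the pxor/psubb sign-fixup byte is 0x80 >> Amt. A standalone C++ sketch (not part of the .ll test) that checks both derivations against the values appearing in the CHECK lines above:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // One shift amount per i8 pair, taken from the ashr constant in the test.
  const unsigned Amts[8] = {1, 4, 7, 3, 2, 0, 5, 6};
  // pmulhuw multipliers expected by the SSE41/AVX CHECK lines ('u' lane = 0).
  const uint16_t ExpectedMul[8] = {32768, 4096, 512, 8192, 16384, 0, 2048, 1024};
  // Per-pair byte of the pxor/psubb sign-fixup vector in the CHECK lines.
  const uint8_t ExpectedSign[8] = {64, 8, 1, 16, 32, 128, 4, 2};
  for (int I = 0; I < 8; ++I) {
    if (Amts[I] != 0) // Amt == 0 has no pmulhuw multiplier; lane is blended.
      assert(uint16_t(1u << (16 - Amts[I])) == ExpectedMul[I]);
    assert(uint8_t(0x80u >> Amts[I]) == ExpectedSign[I]);
  }
  std::puts("pmulhuw and sign-fixup constants match the per-pair shift amounts");
  return 0;
}
```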
