@@ -29830,6 +29830,144 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     }
   }
 
+  // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
+  // using vXi16 vector operations.
+  if (ConstantAmt &&
+      (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
+       (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
+      !Subtarget.hasXOP()) {
+    int NumElts = VT.getVectorNumElements();
+    MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
+    // We can do this extra fast if each pair of i8 elements is shifted by the
+    // same amount by doing it SWAR style: use a shift to move the valid bits
+    // into position, then mask out any bits that crossed from one element to
+    // the other.
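+    // For example (illustrative, assuming a uniform amount): a v16i8 SHL by 5
+    // becomes a v8i16 SHL by 5 followed by an AND with a splat of 0xe0e0,
+    // since (0xff << 5) & 0xff == 0xe0 keeps exactly the bits that belong to
+    // each byte.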
+    APInt UndefElts;
+    SmallVector<APInt, 64> AmtBits;
+    // This optimized lowering is only valid if the elements in a pair can
+    // be treated identically.
+    bool SameShifts = true;
+    SmallVector<APInt, 32> AmtBits16(NumElts / 2);
+    APInt UndefElts16 = APInt::getZero(AmtBits16.size());
+    if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
+                                      AmtBits, /*AllowWholeUndefs=*/true,
+                                      /*AllowPartialUndefs=*/false)) {
+      // Collect information to construct the BUILD_VECTOR for the i16 version
+      // of the shift. Conceptually, this is equivalent to:
+      // 1. Making sure the shift amounts are the same for both the low i8 and
+      //    high i8 corresponding to the i16 lane.
+      // 2. Extending that shift amount to i16 for a build vector operation.
+      //
+      // We want to handle undef shift amounts, which requires a little more
+      // logic (e.g. if one is undef and the other is not, grab the other
+      // shift amount).
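+      // For example (illustrative): the i8 amounts <2, 2, undef, 3, 4, 4>
+      // collapse to the i16 amounts <2, 3, 4>, while a pair like <2, 5> would
+      // make us bail.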
+      for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
+        unsigned DstI = SrcI / 2;
+        // Both elements are undef? Make a note and keep going.
+        if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
+          AmtBits16[DstI] = APInt::getZero(16);
+          UndefElts16.setBit(DstI);
+          continue;
+        }
+        // Even element is undef? We will shift it by the same shift amount as
+        // the odd element.
+        if (UndefElts[SrcI]) {
+          AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
+          continue;
+        }
+        // Odd element is undef? We will shift it by the same shift amount as
+        // the even element.
+        if (UndefElts[SrcI + 1]) {
+          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+          continue;
+        }
+        // Both elements are equal.
+        if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
+          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+          continue;
+        }
+        // One of the provisional i16 elements will not have the same shift
+        // amount. Let's bail.
+        SameShifts = false;
+        break;
+      }
+    }
+    // We are only dealing with identical pairs.
+    if (SameShifts) {
+      // Cast the operand to vXi16.
+      SDValue R16 = DAG.getBitcast(VT16, R);
+      // Create our new vector of shift amounts.
+      SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
+      // Perform the actual shift.
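+      // For ISD::SRA we still shift logically here; the arithmetic behavior
+      // is recovered below by sign extending the logically shifted result.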
+      unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
+      SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
+      // Now we need to construct a mask which will "drop" bits that get
+      // shifted past the LSB/MSB. For a logical shift left, it will look
+      // like:
+      // MaskLowBits = (0xff << Amt16) & 0xff;
+      // MaskHighBits = MaskLowBits << 8;
+      // Mask = MaskLowBits | MaskHighBits;
+      //
+      // This masking ensures that bits cannot migrate from one i8 to
+      // another. The construction of this mask will be constant folded.
+      // The mask for a logical right shift is nearly identical; the only
+      // difference is that 0xff is shifted right instead of left.
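+      // For example (illustrative): for a SHL by 3, MaskLowBits is
+      // (0xff << 3) & 0xff = 0xf8 and Mask = 0xf8f8; for a SRL by 3,
+      // MaskLowBits is 0xff >> 3 = 0x1f and Mask = 0x1f1f.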
+      SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
+      SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
+      // The mask for the low bits is most simply expressed as an 8-bit field
+      // of all ones that is shifted in exactly the same way as the data and
+      // then masked with 0xff.
+      SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
+      MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
+      SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
+      SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
+      // The mask for the high bits is the same as the mask for the low bits
+      // but shifted up by 8.
+      SDValue MaskHighBits =
+          DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
+      SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
+      // Finally, we mask the shifted vector with the SWAR mask.
+      SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
+      Masked = DAG.getBitcast(VT, Masked);
+      if (Opc != ISD::SRA) {
+        // Logical shifts are complete at this point.
+        return Masked;
+      }
+      // At this point, we have done a *logical* shift right. We now need to
+      // sign extend the result so that we get behavior equivalent to an
+      // arithmetic shift right. Post-shifting by Amt16, our i8 elements are
+      // `8-Amt16` bits wide.
+      //
+      // To convert our `8-Amt16` bit unsigned numbers to 8-bit signed numbers,
+      // we need to replicate the bit at position `7-Amt16` into the MSBs of
+      // each i8.
+      // We can use the following trick to accomplish this:
+      // SignBitMask = 1 << (7-Amt16)
+      // (Masked ^ SignBitMask) - SignBitMask
+      //
+      // When the sign bit is already clear, this will compute:
+      // Masked + SignBitMask - SignBitMask
+      //
+      // This is equal to Masked, which is what we want: the sign bit was
+      // clear, so sign extending should be a no-op.
+      //
+      // When the sign bit is set, this will compute:
+      // Masked - SignBitMask - SignBitMask
+      //
+      // This is equal to Masked - 2*SignBitMask, which will correctly sign
+      // extend our result.
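+      // For example (illustrative): for x = 0xa8 (-88) and a shift amount of
+      // 2, the logical shift gives Masked = 0x2a and SignBitMask =
+      // 0x80 >> 2 = 0x20, so (0x2a ^ 0x20) - 0x20 = 0x0a - 0x20 = 0xea (-22),
+      // matching the arithmetic shift -88 >> 2 = -22.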
+      SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
+      SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
+      // This does not induce recursion; all operands are constants.
+      SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
+      SDValue FlippedSignBit =
+          DAG.getNode(ISD::XOR, dl, VT, Masked, SignBitMask);
+      SDValue Subtraction =
+          DAG.getNode(ISD::SUB, dl, VT, FlippedSignBit, SignBitMask);
+      return Subtraction;
+    }
+  }
+
   // If possible, lower this packed shift into a vector multiply instead of
   // expanding it into a sequence of scalar shifts.
   // For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
@@ -29950,105 +30088,18 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
                        DAG.getNode(Opc, dl, ExtVT, R, Amt));
   }
 
-  // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors by
-  // using vXi16 vector operations.
+  // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
+  // extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
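+  // For example (illustrative): x >>u 3 is the high byte of the i16 product
+  // zext(x) * (1 << (8 - 3)), i.e. (x * 32) >> 8; for ISD::SRA the input is
+  // sign extended instead.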
   if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
       (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
        (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
       !Subtarget.hasXOP()) {
     int NumElts = VT.getVectorNumElements();
     MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
-    // We can do this extra fast if each pair of i8 elements is shifted by the
-    // same amount by doing this SWAR style: use a shift to move the valid bits
-    // to the right position, mask out any bits which crossed from one element
-    // to the other.
-    if (Opc == ISD::SRL || Opc == ISD::SHL) {
-      APInt UndefElts;
-      SmallVector<APInt, 64> AmtBits;
-      if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
-                                        AmtBits, /*AllowWholeUndefs=*/true,
-                                        /*AllowPartialUndefs=*/false)) {
-        // This optimized lowering is only valid if the elements in a pair can
-        // be treated identically.
-        bool SameShifts = true;
-        SmallVector<APInt, 32> AmtBits16(NumElts / 2);
-        APInt UndefElts16 = APInt::getZero(AmtBits16.size());
-        for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
-          unsigned DstI = SrcI / 2;
-          // Both elements are undef? Make a note and keep going.
-          if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
-            AmtBits16[DstI] = APInt::getZero(16);
-            UndefElts16.setBit(DstI);
-            continue;
-          }
-          // Even element is undef? We will shift it by the same shift amount
-          // as the odd element.
-          if (UndefElts[SrcI]) {
-            AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
-            continue;
-          }
-          // Odd element is undef? We will shift it by the same shift amount
-          // as the even element.
-          if (UndefElts[SrcI + 1]) {
-            AmtBits16[DstI] = AmtBits[SrcI].zext(16);
-            continue;
-          }
-          // Both elements are equal.
-          if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
-            AmtBits16[DstI] = AmtBits[SrcI].zext(16);
-            continue;
-          }
-          // One of the provisional i16 elements will not have the same shift
-          // amount. Let's bail.
-          SameShifts = false;
-          break;
-        }
-
-        // We are only dealing with identical pairs and the operation is a
-        // logical shift.
-        if (SameShifts) {
-          // Cast the operand to vXi16.
-          SDValue R16 = DAG.getBitcast(VT16, R);
-          // Create our new vector of shift amounts.
-          SDValue Amt16 =
-              getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
-          // Perform the actual shift.
-          SDValue ShiftedR = DAG.getNode(Opc, dl, VT16, R16, Amt16);
-          // Now we need to construct a mask which will "drop" bits that get
-          // shifted past the LSB/MSB. For a logical shift left, it will look
-          // like:
-          // MaskLowBits = (0xff << Amt16) & 0xff;
-          // MaskHighBits = MaskLowBits << 8;
-          // Mask = MaskLowBits | MaskHighBits;
-          //
-          // This masking ensures that bits cannot migrate from one i8 to
-          // another. The construction of this mask will be constant folded.
-          // The mask for a logical right shift is nearly identical, the only
-          // difference is that 0xff is shifted right instead of left.
-          SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
-          SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
-          // The mask for the low bits is most simply expressed as an 8-bit
-          // field of all ones which is shifted in the exact same way the data
-          // is shifted but masked with 0xff.
-          SDValue MaskLowBits = DAG.getNode(Opc, dl, VT16, Splat255, Amt16);
-          MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
-          SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
-          SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
-          // Thie mask for the high bits is the same as the mask for the low
-          // bits but shifted up by 8.
-          SDValue MaskHighBits =
-              DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
-          SDValue Mask =
-              DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
-          // Finally, we mask the shifted vector with the SWAR mask.
-          SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
-          return DAG.getBitcast(VT, Masked);
-        }
-      }
-    }
     SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);
 
-    // Extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI (it
-    // doesn't matter if the type isn't legal).
+    // Extend constant shift amount to vXi16 (it doesn't matter if the type
+    // isn't legal).
     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
     Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT);
     Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt);