@@ -29766,6 +29766,102 @@ static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl,
2976629766 return SDValue();
2976729767}
2976829768
29769+ // Given a vector of values, find a permutation such that every adjacent even-
29770+ // odd pair has the same value. ~0 is reserved as a special value for wildcard,
29771+ // which can be paired with any value. Returns true if a permutation is found.
29772+ template <typename InputTy,
29773+ typename PermutationTy,
29774+ typename MapTy = std::unordered_map<typename InputTy::value_type,
29775+ std::pair<typename InputTy::value_type, typename PermutationTy::value_type>>>
29776+ static bool PermuteAndPairVector(const InputTy& Inputs,
29777+ PermutationTy &Permutation) {
29778+ const auto Wildcard = ~typename InputTy::value_type();
29779+
29780+ // List of values to be paired, mapping an unpaired value to its current
29781+ // neighbor's value and index.
29782+ MapTy UnpairedInputs;
29783+ SmallVector<typename PermutationTy::value_type, 16> WildcardPairs;
29784+
29785+ Permutation.clear();
29786+ typename PermutationTy::value_type I = 0;
29787+ for (auto InputIt = Inputs.begin(), InputEnd = Inputs.end(); InputIt != InputEnd;) {
29788+ Permutation.push_back(I);
29789+ Permutation.push_back(I + 1);
29790+
29791+ auto Even = *InputIt++;
29792+ assert(InputIt != InputEnd && "Expected even number of elements");
29793+ auto Odd = *InputIt++;
29794+
29795+ // If both are wildcards, note it for later use by unpairable values.
29796+ if (Even == Wildcard && Odd == Wildcard) {
29797+ WildcardPairs.push_back(I);
29798+ }
29799+
29800+ // If both are equal, they are in good position.
29801+ if (Even != Odd) {
29802+ auto DoWork = [&] (auto &This, auto ThisIndex, auto Other, auto OtherIndex) {
29803+ if (This != Wildcard) {
29804+ // For non-wildcard value, check if it can pair with an exisiting
29805+ // unpaired value from UnpairedInputs, if so, swap with the unpaired
29806+ // value's neighbor, otherwise the current value is added to the map.
29807+ if (auto [MapIt, Inserted] = UnpairedInputs.try_emplace(This, std::make_pair(Other, OtherIndex)); !Inserted) {
29808+ auto [SwapValue, SwapIndex] = MapIt->second;
29809+ std::swap(Permutation[SwapIndex], Permutation[ThisIndex]);
29810+ This = SwapValue;
29811+ UnpairedInputs.erase(MapIt);
29812+
29813+ if (This == Other) {
29814+ if (This == Wildcard) {
29815+ // We freed up a wildcard pair by pairing two non-adjacent
29816+ // values, note it for later use by unpairable values.
29817+ WildcardPairs.push_back(I);
29818+ } else {
29819+ // The swapped element also forms a pair with Other, so it can
29820+ // be removed from the map.
29821+ assert(UnpairedInputs.count(This));
29822+ UnpairedInputs.erase(This);
29823+ }
29824+ } else {
29825+ // Swapped in an unpaired value, update its info.
29826+ if (This != Wildcard) {
29827+ assert(UnpairedInputs.count(This));
29828+ UnpairedInputs[This] = std::make_pair(Other, OtherIndex);
29829+ }
29830+ // If its neighbor is also in UnpairedInputs, update its info too.
29831+ if (auto OtherMapIt = UnpairedInputs.find(Other); OtherMapIt != UnpairedInputs.end() && OtherMapIt->second.second == ThisIndex) {
29832+ OtherMapIt->second.first = This;
29833+ }
29834+ }
29835+ }
29836+ }
29837+ };
29838+ DoWork(Even, I, Odd, I + 1);
29839+ if (Even != Odd) {
29840+ DoWork(Odd, I + 1, Even, I);
29841+ }
29842+ }
29843+ I += 2;
29844+ }
29845+
29846+ // Now check if each remaining unpaired neighboring values can be swapped with
29847+ // a wildcard pair to form two paired values.
29848+ for (auto &[Unpaired, V] : UnpairedInputs) {
29849+ auto [Neighbor, NeighborIndex] = V;
29850+ if (Neighbor != Wildcard) {
29851+ assert(UnpairedInputs.count(Neighbor));
29852+ if (WildcardPairs.size()) {
29853+ std::swap(Permutation[WildcardPairs.back()], Permutation[NeighborIndex]);
29854+ WildcardPairs.pop_back();
29855+ // Mark the neighbor as processed.
29856+ UnpairedInputs[Neighbor].first = Wildcard;
29857+ } else {
29858+ return false;
29859+ }
29860+ }
29861+ }
29862+ return true;
29863+ }
29864+
2976929865static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
2977029866 SelectionDAG &DAG) {
2977129867 MVT VT = Op.getSimpleValueType();
@@ -30044,6 +30140,110 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
3004430140 }
3004530141 }
3004630142
30143+ // ISD::SRA/SRL/SHL on vXi8 can be widened to vYi16 (Y = X/2) if the constant
30144+ // amounts can be shuffled such that every pair of adjacent elements has the
30145+ // same value. This introduces an extra shuffle before and after the shift,
30146+ // and it is profitable if the operand is aready a shuffle so that both can
30147+ // be merged, or if the extra shuffle is fast (can use VPSHUFB).
30148+ // (shift (shuffle X P1) S1) ->
30149+ // (shuffle (shift (shuffle X (shuffle P2 P1)) S2) P2^-1) where S2 can be
30150+ // widened, and P2^-1 is the inverse shuffle of P2.
30151+ if (ConstantAmt && (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) && R.hasOneUse() && Subtarget.hasSSE3()) {
30152+ bool Profitable = true;
30153+ // VPAND ymm only available on AVX2.
30154+ if (VT == MVT::v32i8 || VT == MVT::v64i8) {
30155+ Profitable = Subtarget.hasAVX2();
30156+ }
30157+
30158+ SmallVector<int, 64> Permutation;
30159+ SmallVector<uint16_t, 64> ShiftAmt;
30160+ for (size_t I = 0; I < Amt.getNumOperands(); ++I) {
30161+ if (Amt.getOperand(I).isUndef())
30162+ ShiftAmt.push_back(~0);
30163+ else
30164+ ShiftAmt.push_back(Amt.getConstantOperandVal(I));
30165+ }
30166+
30167+ if (Profitable && (VT == MVT::v32i8 || VT == MVT::v64i8)) {
30168+ Profitable = false;
30169+ constexpr size_t LaneBytes = 16;
30170+ const size_t NumLanes = VT.getVectorNumElements() / LaneBytes;
30171+
30172+ // For v32i8 or v64i8, we should check if we can generate a shuffle that
30173+ // may be lowered to VPSHUFB, because it is faster than VPERMB. This is
30174+ // possible if we can apply the same shuffle mask to each v16i8 lane.
30175+ // For example (assuming a lane has 4 elements for simplicity),
30176+ // <1, 2, 2, 1, 4, 3, 3, 4> is handled as <14, 23, 23, 14>, which can
30177+ // be shuffled to adjacent pairs <14, 14, 23, 23> with the VPSHUFB mask
30178+ // <0, 3, 2, 1> (or high level mask <0, 3, 2, 1, 4, 7, 6, 5>).
30179+ // Limitation: if there are some undef in shift amounts, this algorithm
30180+ // may not find a solution even if one exists, as here we only treat a
30181+ // VPSHUFB index as undef if all shuffle amounts of the same index modulo
30182+ // lane size are all undef.
30183+ // Since a byte can only be shifted by 7 bits without being UB, 4 bits are
30184+ // enough to represent the shift amount or undef (0xF).
30185+ std::array<uint16_t, LaneBytes> VPSHUFBShiftAmt = {};
30186+ for (size_t I = 0; I < LaneBytes; ++I)
30187+ for (size_t J = 0; J < NumLanes; ++J)
30188+ VPSHUFBShiftAmt[I] |= (ShiftAmt[I + J * LaneBytes] & 0xF) << (J * 4);
30189+ if (VT == MVT::v32i8) {
30190+ for (size_t I = 0; I < LaneBytes; ++I)
30191+ VPSHUFBShiftAmt[I] |= 0xFF00;
30192+ }
30193+ if (PermuteAndPairVector(VPSHUFBShiftAmt, Permutation)) {
30194+ // Found a VPSHUFB solution, offset the shuffle amount to other lanes.
30195+ Permutation.resize(VT.getVectorNumElements());
30196+ for (size_t I = 0; I < LaneBytes; ++I)
30197+ for (size_t J = 1; J < NumLanes; ++J)
30198+ Permutation[I + J * LaneBytes] = Permutation[I] + J * LaneBytes;
30199+ Profitable = true;
30200+ } else if (R.getOpcode() == ISD::VECTOR_SHUFFLE) {
30201+ // A slower shuffle is profitable if the operand is also a slow shuffle,
30202+ // such that they can be merged.
30203+ // TODO: Use TargetTransformInfo to systematically determine whether
30204+ // inner shuffle is slow. Currently we only check if it contains
30205+ // cross-lane shuffle.
30206+ if (ShuffleVectorSDNode *InnerShuffle = dyn_cast<ShuffleVectorSDNode>(R.getNode())) {
30207+ if (InnerShuffle->getMask().size() == VT.getVectorNumElements() &&
30208+ is128BitLaneCrossingShuffleMask(VT, InnerShuffle->getMask()))
30209+ Profitable = true;
30210+ }
30211+ }
30212+ }
30213+
30214+ // If it is still profitable at this point, and has not found a permutation
30215+ // yet, try again with any shuffle index.
30216+ if (Profitable && Permutation.empty()) {
30217+ PermuteAndPairVector<decltype(ShiftAmt), decltype(Permutation),
30218+ SmallMapVector<uint16_t, std::pair<uint16_t, int>, 8>>(ShiftAmt, Permutation);
30219+ }
30220+
30221+ // Found a permutation P that can rearrange the shift amouts into adjacent
30222+ // pair of same values. Rewrite the shift S1(x) into P^-1(S2(P(x))).
30223+ if (!Permutation.empty()) {
30224+ SDValue InnerShuffle = DAG.getVectorShuffle(VT, dl, R, DAG.getUNDEF(VT), Permutation);
30225+ SmallVector<SDValue, 64> NewShiftAmt;
30226+ for (int Index : Permutation) {
30227+ NewShiftAmt.push_back(Amt.getOperand(Index));
30228+ }
30229+ #ifndef NDEBUG
30230+ for (size_t I = 0; I < NewShiftAmt.size(); I += 2) {
30231+ SDValue Even = NewShiftAmt[I];
30232+ SDValue Odd = NewShiftAmt[I + 1];
30233+ assert(Even.isUndef() || Odd.isUndef() || Even->getAsZExtVal() == Odd->getAsZExtVal());
30234+ }
30235+ #endif
30236+ SDValue NewShiftVector = DAG.getBuildVector(VT, dl, NewShiftAmt);
30237+ SDValue NewShift = DAG.getNode(Opc, dl, VT, InnerShuffle, NewShiftVector);
30238+ SmallVector<int, 64> InversePermutation(Permutation.size());
30239+ for (size_t I = 0; I < Permutation.size(); ++I) {
30240+ InversePermutation[Permutation[I]] = I;
30241+ }
30242+ SDValue OuterShuffle = DAG.getVectorShuffle(VT, dl, NewShift, DAG.getUNDEF(VT), InversePermutation);
30243+ return OuterShuffle;
30244+ }
30245+ }
30246+
3004730247 // If possible, lower this packed shift into a vector multiply instead of
3004830248 // expanding it into a sequence of scalar shifts.
3004930249 // For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
0 commit comments