Skip to content
237 changes: 237 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29766,6 +29766,113 @@ static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl,
return SDValue();
}

// Given a vector of values, find a permutation such that every adjacent even-
// odd pair has the same value. ~0 is reserved as a special value for wildcard,
// which can be paired with any value. Returns true if a permutation is found.
// If output Permutation is not empty, permutation index starts at its previous
// size, so that this function can concatenate the result of multiple calls.
// UnpairedInputs contains values yet to be paired, mapping an unpaired value to
// its current neighbor's value and index.
// Do not use llvm::DenseMap as ~0 is reserved key.
template <typename InputTy, typename PermutationTy,
          typename MapTy =
              SmallMapVector<typename InputTy::value_type,
                             std::pair<typename InputTy::value_type,
                                       typename PermutationTy::value_type>,
                             8>>
static bool PermuteAndPairVector(
    const InputTy &Inputs, PermutationTy &Permutation,
    MapTy UnpairedInputs = MapTy()) {
  // Wildcard is the all-ones bit pattern of the input value type (e.g. 0xFF
  // for uint8_t elements); it matches any value when pairing.
  const auto Wildcard = ~typename InputTy::value_type();
  // Indices (even index of the pair) of pairs where both elements are
  // wildcards. These are spare slots that the fix-up loop at the end can swap
  // into to resolve otherwise-unpairable values.
  SmallVector<typename PermutationTy::value_type, 16> WildcardPairs;

  size_t OutputOffset = Permutation.size();
  typename PermutationTy::value_type I = 0;
  // Walk the input two elements at a time; I always points at the even slot
  // of the current pair. Start with the identity permutation and fix it up by
  // swapping entries as mismatched pairs are discovered.
  for (auto InputIt = Inputs.begin(), InputEnd = Inputs.end();
       InputIt != InputEnd;) {
    Permutation.push_back(OutputOffset + I);
    Permutation.push_back(OutputOffset + I + 1);

    auto Even = *InputIt++;
    assert(InputIt != InputEnd && "Expected even number of elements");
    auto Odd = *InputIt++;

    // If both are wildcards, note it for later use by unpairable values.
    if (Even == Wildcard && Odd == Wildcard) {
      WildcardPairs.push_back(I);
    }

    // If both are equal, they are in good position.
    if (Even != Odd) {
      // Try to resolve one mismatched element ("This") whose neighbor is
      // "Other". Called once for the even element and, if still unresolved,
      // once for the odd one.
      auto DoWork = [&](auto &This, auto ThisIndex, auto Other,
                        auto OtherIndex) {
        if (This != Wildcard) {
          // For non-wildcard value, check if it can pair with an existing
          // unpaired value from UnpairedInputs, if so, swap with the unpaired
          // value's neighbor, otherwise the current value is added to the map.
          if (auto [MapIt, Inserted] = UnpairedInputs.try_emplace(
                  This, std::make_pair(Other, OtherIndex));
              !Inserted) {
            // A previous unpaired element has the same value: swap this
            // element with that element's neighbor so the two equal values
            // become adjacent. After the swap, "This" holds the value that
            // used to sit next to the matched element.
            auto [SwapValue, SwapIndex] = MapIt->second;
            std::swap(Permutation[OutputOffset + SwapIndex],
                      Permutation[OutputOffset + ThisIndex]);
            This = SwapValue;
            UnpairedInputs.erase(MapIt);

            if (This == Other) {
              if (This == Wildcard) {
                // We freed up a wildcard pair by pairing two non-adjacent
                // values, note it for later use by unpairable values.
                WildcardPairs.push_back(I);
              } else {
                // The swapped element also forms a pair with Other, so it can
                // be removed from the map.
                assert(UnpairedInputs.count(This));
                UnpairedInputs.erase(This);
              }
            } else {
              // Swapped in an unpaired value, update its info.
              if (This != Wildcard) {
                assert(UnpairedInputs.count(This));
                UnpairedInputs[This] = std::make_pair(Other, OtherIndex);
              }
              // If its neighbor is also in UnpairedInputs, update its info too.
              if (auto OtherMapIt = UnpairedInputs.find(Other);
                  OtherMapIt != UnpairedInputs.end() &&
                  OtherMapIt->second.second == ThisIndex) {
                OtherMapIt->second.first = This;
              }
            }
          }
        }
      };
      DoWork(Even, I, Odd, I + 1);
      // DoWork may have rewritten Even; only process Odd if the pair is still
      // mismatched.
      if (Even != Odd) {
        DoWork(Odd, I + 1, Even, I);
      }
    }
    I += 2;
  }

  // Now check if each remaining pair of unpaired neighboring values can be
  // swapped with a wildcard pair to form two paired values.
  for (auto &[Unpaired, V] : UnpairedInputs) {
    auto [Neighbor, NeighborIndex] = V;
    if (Neighbor != Wildcard) {
      // Unpaired values come in mutual pairs: the neighbor must itself be in
      // the map, pointing back at this element.
      assert(UnpairedInputs.count(Neighbor));
      if (WildcardPairs.size()) {
        // Swap the neighbor into a spare all-wildcard pair: this pairs
        // Unpaired with a wildcard, and Neighbor with the other wildcard.
        std::swap(Permutation[OutputOffset + WildcardPairs.back()],
                  Permutation[OutputOffset + NeighborIndex]);
        WildcardPairs.pop_back();
        // Mark the neighbor as processed.
        UnpairedInputs[Neighbor].first = Wildcard;
      } else
        return false;
    }
  }
  return true;
}

static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
Expand Down Expand Up @@ -30044,6 +30151,136 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
}
}

// SHL/SRL/SRA on vXi8 can be widened to vYi16 or vYi32 if the constant
// amounts can be shuffled such that every pair or quad of adjacent elements
// has the same value. This introduces an extra shuffle before and after the
shift, and it is profitable if the operand is already a shuffle so that both
// can be merged and the extra shuffle is fast. This is not profitable on
AVX512 because it has the 16-bit vector variable shift instruction VPS**VW.
// (shift (shuffle X P1) S1) ->
// (shuffle (shift (shuffle X (shuffle P2 P1)) S2) P2^-1) where S2 can be
// widened, and P2^-1 is the inverse shuffle of P2.
if (ConstantAmt &&
(VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) &&
R.hasOneUse() && Subtarget.hasSSE3() && !Subtarget.hasAVX512()) {
constexpr size_t LaneBytes = 16;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lowering scheme immediately above this is very similar to what you're doing (and a lot easier to grok) - I'd recommend you look at extending that code instead of introducing this separate implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code above handles shift widening when adjacent pairs already have the same shift amount. My patch tries to find a permutation that creates such a shift, but does not perform the widening itself (it hands that off to the code above). So it is in fact different functionality, and is better left in a separate section.

const size_t NumLanes = VT.getVectorNumElements() / LaneBytes;

SmallVector<int, 64> Permutation;
SmallVector<uint8_t, 64> ShiftAmt;
for (size_t I = 0; I < Amt.getNumOperands(); ++I) {
if (Amt.getOperand(I).isUndef())
ShiftAmt.push_back(~0);
else
ShiftAmt.push_back(Amt.getConstantOperandVal(I));
}

// Check if we can find an in-lane shuffle to rearrange the shift amounts,
// if so, this transformation may be profitable.
bool Profitable;
for (size_t I = 0; I < NumLanes; ++I) {
if (!(Profitable = PermuteAndPairVector(
ArrayRef(&ShiftAmt[I * LaneBytes], LaneBytes), Permutation)))
break;
}

// For AVX2, check if we can further rearrange shift amounts into adjacent
// quads, so that it can use VPS*LVD instead of VPMUL*W as it is 2 cycles
// faster.
bool IsAdjacentQuads = false;
if (Profitable && Subtarget.hasAVX2()) {
SmallVector<uint8_t, 64> EveryOtherShiftAmt;
for (size_t I = 0; I < Permutation.size(); I += 2) {
uint8_t Shift1 = ShiftAmt[Permutation[I]];
uint8_t Shift2 = ShiftAmt[Permutation[I + 1]];
assert(Shift1 == Shift2 || ~Shift1 == 0 || ~Shift2 == 0);
EveryOtherShiftAmt.push_back(~Shift1 ? Shift1 : Shift2);
}
SmallVector<int, 32> Permutation2;
for (size_t I = 0; I < NumLanes; ++I) {
if (!(IsAdjacentQuads = PermuteAndPairVector(
ArrayRef(&EveryOtherShiftAmt[I * LaneBytes / 2],
LaneBytes / 2),
Permutation2)))
break;
}
if (IsAdjacentQuads) {
SmallVector<int, 64> CombinedPermutation;
for (int Index : Permutation2) {
CombinedPermutation.push_back(Permutation[Index * 2]);
CombinedPermutation.push_back(Permutation[Index * 2 + 1]);
}
std::swap(Permutation, CombinedPermutation);
}
}

// For right shifts, (V)PMULHUW needs an extra instruction to handle an
// amount of 0, disabling the transformation here to be cautious.
if (!IsAdjacentQuads && (Opc == ISD::SRL || Opc == ISD::SRA) &&
any_of(ShiftAmt, [](auto x) { return x == 0; }))
Profitable = false;

bool IsOperandShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE;
// If operand R is not a shuffle by itself, the transformation here adds two
// shuffles, adding a non-trivial cost. Here we take out a few cases where
// the benefit is questionable according to llvm-mca's modeling.
//
// Each cell shows latency before/after transform. Here R is not a shuffle.
// SSE3
// | v16i8 | v32i8 | v64i8
// ----------------------------
// SLL | 17/17 | 20/20 | 26/26
// SRL | 18/17 | 22/20 | 35/26
// SRA | 21/19 | 26/22 | 39/30
// AVX2 using VPMUL*W
// | v16i8 | v32i8 | v64i8
// ----------------------------
// SLL | 20/18 | 18/18 | 21/21
// SRL | 20/18 | 22/18 | 26/21
// SRA | 20/20 | 22/20 | 25/23
// AVX2 using VPS*LVD
// | v16i8 | v32i8 | v64i8
// ----------------------------
// SLL | 20/16 | 18/16 | 21/20
// SRL | 20/16 | 22/16 | 26/20
// SRA | 20/18 | 22/18 | 25/22
if (!IsOperandShuffle) {
if (Subtarget.hasAVX2()) {
if (!IsAdjacentQuads || (VT == MVT::v64i8 && Opc == ISD::SHL))
Profitable = false;
} else {
if (Opc == ISD::SHL ||
((VT == MVT::v16i8 || VT == MVT::v32i8) && Opc == ISD::SRL))
Profitable = false;
}
}

// Found a permutation P that can rearrange the shift amounts into adjacent
// pair or quad of same values. Rewrite the shift S1(x) into P^-1(S2(P(x))).
if (Profitable) {
SDValue InnerShuffle = DAG.getVectorShuffle(VT, dl, R, DAG.getUNDEF(VT), Permutation);
SmallVector<SDValue, 64> NewShiftAmt;
for (int Index : Permutation) {
NewShiftAmt.push_back(Amt.getOperand(Index));
}
#ifndef NDEBUG
for (size_t I = 0; I < NewShiftAmt.size(); I += 2) {
SDValue Even = NewShiftAmt[I];
SDValue Odd = NewShiftAmt[I + 1];
assert(Even.isUndef() || Odd.isUndef() || Even->getAsZExtVal() == Odd->getAsZExtVal());
}
#endif
SDValue NewShiftVector = DAG.getBuildVector(VT, dl, NewShiftAmt);
SDValue NewShift = DAG.getNode(Opc, dl, VT, InnerShuffle, NewShiftVector);
SmallVector<int, 64> InversePermutation(Permutation.size());
for (size_t I = 0; I < Permutation.size(); ++I) {
InversePermutation[Permutation[I]] = I;
}
SDValue OuterShuffle = DAG.getVectorShuffle(VT, dl, NewShift, DAG.getUNDEF(VT), InversePermutation);
return OuterShuffle;
}
}

// If possible, lower this packed shift into a vector multiply instead of
// expanding it into a sequence of scalar shifts.
// For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
Expand Down
Loading