3 changes: 3 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -371,6 +371,9 @@ LLVM_ABI bool canSinkOrHoistInst(Instruction &I, AAResults *AA,
/// Returns the llvm.vector.reduce intrinsic that corresponds to the recurrence
/// kind.
LLVM_ABI constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
/// Returns the llvm.vector.reduce min/max intrinsic that corresponds to the
/// given min/max intrinsic.
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicID(Intrinsic::ID IID);

/// Returns the arithmetic instruction opcode used when expanding a reduction.
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
15 changes: 15 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -956,6 +956,21 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
}
}

Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
switch (IID) {
default:
llvm_unreachable("Unexpected intrinsic id");
case Intrinsic::umin:
return Intrinsic::vector_reduce_umin;
case Intrinsic::umax:
return Intrinsic::vector_reduce_umax;
case Intrinsic::smin:
return Intrinsic::vector_reduce_smin;
case Intrinsic::smax:
return Intrinsic::vector_reduce_smax;
}
}

// This is the inverse to getReductionForBinop
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
switch (RdxID) {
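For context, this is roughly how the fold added below consumes the new helper. An illustrative sketch only, not additional patch code; `emitMinMaxReduce` is a hypothetical wrapper:

```cpp
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
using namespace llvm;

// Sketch: given the common elementwise min/max op of a matched chain
// (umin/umax/smin/smax) and the chain's source vector, emit the single
// equivalent llvm.vector.reduce.* call.
static Value *emitMinMaxReduce(IRBuilder<> &Builder, Intrinsic::ID IID,
                               Value *Vec) {
  // Any other intrinsic ID hits the llvm_unreachable in the helper.
  Intrinsic::ID RdxID = getMinMaxReductionIntrinsicID(IID);
  return Builder.CreateIntrinsic(RdxID, {Vec->getType()}, {Vec});
}
```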
267 changes: 267 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -135,6 +135,7 @@ class VectorCombine {
bool foldShuffleOfIntrinsics(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldShuffleChainsToReduce(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
bool foldInterleaveIntrinsics(Instruction &I);
@@ -3129,6 +3130,269 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
return MadeChanges;
}

/// For a given chain of instructions matching the following pattern:
///
/// ```
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison, <n x ty2> mask
///
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
/// ty1> %1)
/// OR
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
///
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison, <n x ty2> mask
/// ...
/// ...
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
/// 3), <n x ty1> %(i - 2))
/// OR
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
///
/// %(i) = extractelement <n x ty1> %(i - 1), 0
/// ```
///
/// Where:
/// `mask` follows a partition pattern:
///
/// Ex:
/// [n = 8, p = poison]
///
/// 4 5 6 7 | p p p p
/// 2 3 | p p p p p p
/// 1 | p p p p p p p
///
/// For powers of 2 the partition is uniform, but for other sizes the
/// parity of the current half size at each step decides the next
/// partition size (see `ExpectedParityMask` in the implementation for
/// how this is generalised).
///
/// Ex:
/// [n = 6]
///
/// 3 4 5 | p p p
/// 1 2 | p p p p
/// 1 | p p p p p
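///
/// A concrete instance (n = 4, umax; value names are illustrative):
///
/// ```
/// %s1 = shufflevector <4 x i32> %v, <4 x i32> poison,
///                     <4 x i32> <i32 2, i32 3, i32 poison, i32 poison>
/// %m1 = tail call <4 x i32> @llvm.umax.v4i32(<4 x i32> %v, <4 x i32> %s1)
/// %s2 = shufflevector <4 x i32> %m1, <4 x i32> poison,
///                     <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
/// %m2 = tail call <4 x i32> @llvm.umax.v4i32(<4 x i32> %m1, <4 x i32> %s2)
/// %r  = extractelement <4 x i32> %m2, i64 0
/// ```
///
/// which, when the cost model agrees, becomes
/// `%r = call i32 @llvm.vector.reduce.umax.v4i32(<4 x i32> %v)`.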
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
// Going bottom-up for the pattern.
std::queue<Value *> InstWorklist;
InstructionCost OrigCost = 0;

// The common intrinsic or binary op applied after each shuffle op.
std::optional<unsigned int> CommonCallOp = std::nullopt;
std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;

bool IsFirstCallOrBinInst = true;
bool ShouldBeCallOrBinInst = true;

// This stores the values last visited in the chain.
//
// PrevVecV[2] stores the source vector of the extractelement
// instruction, while PrevVecV[0] / PrevVecV[1] store the two operands
// of the most recently visited shuffle/common op.
SmallVector<Value *, 3> PrevVecV(3, nullptr);

Value *VecOp;
if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
return false;

auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
if (!FVT)
return false;

int64_t VecSize = FVT->getNumElements();
if (VecSize < 2)
return false;

// The number of levels is ceil(log2(n)), since we always partition
// by half for this fold pattern.
unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;

// This is how we generalise to all element counts.
// At each step, if the vector size is odd, the non-poison values must
// cover the larger half so that no element is missed.
//
// This mask lets us recover the half size at each level as we walk
// from bottom to top:
//
// Mask bit set   -> N = N * 2 - 1
// Mask bit unset -> N = N * 2
for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
Cur = (Cur + 1) / 2, --Mask) {
if (Cur & 1)
ExpectedParityMask |= (1ll << Mask);
}
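// For example, with VecSize = 6: NumLevels = 3 and the loop visits
// Cur = 6 (even), Cur = 3 (odd, sets bit 1) and Cur = 2 (even), giving
// ExpectedParityMask = 0b010. Walking bottom-up, ShuffleMaskHalf then
// grows as 1 -> 2 -> 3, matching the n = 6 example above.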

PrevVecV[2] = VecOp;
InstWorklist.push(PrevVecV[2]);

while (!InstWorklist.empty()) {
Value *CI = InstWorklist.front();
InstWorklist.pop();

if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
if (!ShouldBeCallOrBinInst)
return false;

if (!IsFirstCallOrBinInst &&
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;

// For the first found call/bin op, the vector has to come from the
// extract element op.
if (II != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
return false;
IsFirstCallOrBinInst = false;

if (!CommonCallOp)
CommonCallOp = II->getIntrinsicID();
if (II->getIntrinsicID() != *CommonCallOp)
return false;

switch (II->getIntrinsicID()) {
case Intrinsic::umin:
case Intrinsic::umax:
case Intrinsic::smin:
case Intrinsic::smax: {
auto *Op0 = II->getOperand(0);
auto *Op1 = II->getOperand(1);
PrevVecV[0] = Op0;
PrevVecV[1] = Op1;
break;
}
default:
return false;
}
ShouldBeCallOrBinInst ^= 1;

IntrinsicCostAttributes ICA(
*CommonCallOp, II->getType(),
{PrevVecV[0]->getType(), PrevVecV[1]->getType()});
OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);

// The operands can appear as (a, b) or (b, a); swap if needed so that
// PrevVecV[1] holds the shuffle and PrevVecV[0] its source as we go up.
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
std::swap(PrevVecV[0], PrevVecV[1]);
InstWorklist.push(PrevVecV[1]);
InstWorklist.push(PrevVecV[0]);
} else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
// Similar logic for bin ops.

if (!ShouldBeCallOrBinInst)
return false;

if (!IsFirstCallOrBinInst &&
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;

if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
return false;
IsFirstCallOrBinInst = false;

if (!CommonBinOp)
CommonBinOp = BinOp->getOpcode();

if (BinOp->getOpcode() != *CommonBinOp)
return false;

switch (*CommonBinOp) {
case BinaryOperator::Add:
case BinaryOperator::Mul:
case BinaryOperator::Or:
case BinaryOperator::And:
case BinaryOperator::Xor: {
auto *Op0 = BinOp->getOperand(0);
auto *Op1 = BinOp->getOperand(1);
PrevVecV[0] = Op0;
PrevVecV[1] = Op1;
break;
}
default:
return false;
}
ShouldBeCallOrBinInst ^= 1;

OrigCost +=
TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);

if (!isa<ShuffleVectorInst>(PrevVecV[1]))
std::swap(PrevVecV[0], PrevVecV[1]);
InstWorklist.push(PrevVecV[1]);
InstWorklist.push(PrevVecV[0]);
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
// We shouldn't have any null values in the previous vectors;
// if so, there was a mismatch in the pattern.
if (ShouldBeCallOrBinInst ||
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;

if (SVInst != PrevVecV[1])
return false;

ArrayRef<int> CurMask;
if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
m_Mask(CurMask))))
return false;

// Subtract the current parity bit when checking the expected mask values.
for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
if (Mask < ShuffleMaskHalf &&
CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
return false;
if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
return false;
}

// Update mask values.
ShuffleMaskHalf *= 2;
ShuffleMaskHalf -= (ExpectedParityMask & 1);
ExpectedParityMask >>= 1;

OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
SVInst->getType(), SVInst->getType(),
CurMask, CostKind);

VisitedCnt += 1;
if (!ExpectedParityMask && VisitedCnt == NumLevels)
break;

ShouldBeCallOrBinInst ^= 1;
} else {
return false;
}
}

// Pattern should end with a shuffle op.
if (ShouldBeCallOrBinInst)
return false;

assert(VecSize != -1 && "Expected Match for Vector Size");

Value *FinalVecV = PrevVecV[0];
if (!FinalVecV)
return false;

auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());

Intrinsic::ID ReducedOp =
(CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
: getReductionForBinop(*CommonBinOp));
if (!ReducedOp)
return false;

IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

if (NewCost >= OrigCost)
return false;

auto *ReducedResult =
Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
replaceValue(I, *ReducedResult);

return true;
}

/// Determine if its more efficient to fold:
/// reduce(trunc(x)) -> trunc(reduce(x)).
/// reduce(sext(x)) -> sext(reduce(x)).
Expand Down Expand Up @@ -4216,6 +4480,9 @@ bool VectorCombine::run() {
if (foldCastFromReductions(I))
return true;
break;
case Instruction::ExtractElement:
if (foldShuffleChainsToReduce(I))
return true;
break;
case Instruction::ICmp:
case Instruction::FCmp:
if (foldExtractExtract(I))