Skip to content

Commit a6acccd

Browse files
committed
Include support for Add/Mul/Or/And/Xor Binary Operations
1 parent 701ec74 commit a6acccd

File tree

4 files changed

+260
-84
lines changed

4 files changed

+260
-84
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,9 @@ LLVM_ABI bool canSinkOrHoistInst(Instruction &I, AAResults *AA,
371371
/// Returns the llvm.vector.reduce intrinsic that corresponds to the recurrence
372372
/// kind.
373373
LLVM_ABI constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
374+
/// Returns the llvm.vector.reduce min/max intrinsic that corresponds to the
375+
/// intrinsic op.
376+
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicID(Intrinsic::ID IID);
374377

375378
/// Returns the arithmetic instruction opcode used when expanding a reduction.
376379
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,21 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
956956
}
957957
}
958958

959+
Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
960+
switch (IID) {
961+
default:
962+
llvm_unreachable("Unexpected intrinsic id");
963+
case Intrinsic::umin:
964+
return Intrinsic::vector_reduce_umin;
965+
case Intrinsic::umax:
966+
return Intrinsic::vector_reduce_umax;
967+
case Intrinsic::smin:
968+
return Intrinsic::vector_reduce_smin;
969+
case Intrinsic::smax:
970+
return Intrinsic::vector_reduce_smax;
971+
}
972+
}
973+
959974
// This is the inverse to getReductionForBinop
960975
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
961976
switch (RdxID) {

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 174 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3130,21 +3130,66 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
31303130
return MadeChanges;
31313131
}
31323132

3133+
/// For a given chain of patterns of the following form:
3134+
///
3135+
/// ```
3136+
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3137+
///
3138+
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3139+
/// ty1> %1)
3140+
/// OR
3141+
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3142+
///
3143+
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3144+
/// ...
3145+
/// ...
3146+
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3147+
/// 3), <n x ty1> %(i - 2)
3148+
/// OR
3149+
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3150+
///
3151+
/// %(i) = extractelement <n x ty1> %(i - 1), 0
3152+
/// ```
3153+
///
3154+
/// Where:
3155+
/// `mask` follows a partition pattern:
3156+
///
3157+
/// Ex:
3158+
/// [n = 8, p = poison]
3159+
///
3160+
/// 4 5 6 7 | p p p p
3161+
/// 2 3 | p p p p p p
3162+
/// 1 | p p p p p p p
3163+
///
3164+
/// For powers of 2, there's a consistent pattern, but for other cases
3165+
/// the parity of the current half value at each step decides the
3166+
/// next partition half (see `ExpectedParityMask` for more logical details
3167+
/// in generalising this).
3168+
///
3169+
/// Ex:
3170+
/// [n = 6]
3171+
///
3172+
/// 3 4 5 | p p p
3173+
/// 1 2 | p p p p
3174+
/// 1 | p p p p p
31333175
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3134-
auto *EEI = dyn_cast<ExtractElementInst>(&I);
3135-
if (!EEI)
3136-
return false;
3137-
3176+
// Going bottom-up for the pattern.
31383177
std::queue<Value *> InstWorklist;
3139-
Value *InitEEV = nullptr;
3140-
Intrinsic::ID CommonOp = 0;
3178+
InstructionCost OrigCost = 0;
3179+
3180+
// Common instruction operation after each shuffle op.
3181+
std::optional<unsigned int> CommonCallOp = std::nullopt;
3182+
std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;
31413183

3142-
bool IsFirstCallInst = true;
3143-
bool ShouldBeCallInst = true;
3184+
bool IsFirstCallOrBinInst = true;
3185+
bool ShouldBeCallOrBinInst = true;
31443186

3187+
// This stores the last used instructions for shuffle/common op.
3188+
//
3189+
// PrevVecV[2] stores the first vector from extract element instruction,
3190+
// while PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3191+
// instructions from either shuffle/common op.
31453192
SmallVector<Value *, 3> PrevVecV(3, nullptr);
3146-
int64_t ShuffleMaskHalf = -1, ExpectedShuffleMaskHalf = 1;
3147-
int64_t VecSize = -1;
31483193

31493194
Value *VecOp;
31503195
if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
@@ -3154,141 +3199,186 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
31543199
if (!FVT)
31553200
return false;
31563201

3157-
VecSize = FVT->getNumElements();
3158-
if (VecSize < 2 || (VecSize % 2) != 0)
3202+
int64_t VecSize = FVT->getNumElements();
3203+
if (VecSize < 2)
31593204
return false;
31603205

3161-
ShuffleMaskHalf = 1;
3162-
PrevVecV[2] = VecOp;
3163-
InitEEV = EEI;
3206+
// Number of levels would be ~log2(n), considering we always partition
3207+
// by half for this fold pattern.
3208+
unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
3209+
int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
3210+
3211+
// This is how we generalise for all element sizes.
3212+
// At each step, if vector size is odd, we need non-poison
3213+
// values to cover the dominant half so we don't miss out on any element.
3214+
//
3215+
// This mask will help us retrieve this as we go from bottom to top:
3216+
//
3217+
// Mask Set -> N = N * 2 - 1
3218+
// Mask Unset -> N = N * 2
3219+
for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
3220+
Cur = (Cur + 1) / 2, --Mask) {
3221+
if (Cur & 1)
3222+
ExpectedParityMask |= (1ll << Mask);
3223+
}
31643224

3225+
PrevVecV[2] = VecOp;
31653226
InstWorklist.push(PrevVecV[2]);
31663227

31673228
while (!InstWorklist.empty()) {
3168-
Value *V = InstWorklist.front();
3229+
Value *CI = InstWorklist.front();
31693230
InstWorklist.pop();
31703231

3171-
auto *CI = dyn_cast<Instruction>(V);
3172-
if (!CI)
3173-
return false;
3174-
3175-
if (auto *CallI = dyn_cast<CallInst>(CI)) {
3176-
if (!ShouldBeCallInst || !PrevVecV[2])
3232+
if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
3233+
if (!ShouldBeCallOrBinInst)
31773234
return false;
31783235

3179-
if (!IsFirstCallInst &&
3236+
if (!IsFirstCallOrBinInst &&
31803237
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
31813238
return false;
31823239

3183-
if (CallI != (IsFirstCallInst ? PrevVecV[2] : PrevVecV[0]))
3240+
// For the first found call/bin op, the vector has to come from the
3241+
// extract element op.
3242+
if (II != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
31843243
return false;
3185-
IsFirstCallInst = false;
3244+
IsFirstCallOrBinInst = false;
31863245

3187-
auto *II = dyn_cast<IntrinsicInst>(CallI);
3188-
if (!II)
3189-
return false;
3190-
3191-
if (!CommonOp)
3192-
CommonOp = II->getIntrinsicID();
3193-
if (II->getIntrinsicID() != CommonOp)
3246+
if (!CommonCallOp)
3247+
CommonCallOp = II->getIntrinsicID();
3248+
if (II->getIntrinsicID() != *CommonCallOp)
31943249
return false;
31953250

31963251
switch (II->getIntrinsicID()) {
31973252
case Intrinsic::umin:
31983253
case Intrinsic::umax:
31993254
case Intrinsic::smin:
32003255
case Intrinsic::smax: {
3201-
auto *Op0 = CallI->getOperand(0);
3202-
auto *Op1 = CallI->getOperand(1);
3256+
auto *Op0 = II->getOperand(0);
3257+
auto *Op1 = II->getOperand(1);
32033258
PrevVecV[0] = Op0;
32043259
PrevVecV[1] = Op1;
32053260
break;
32063261
}
32073262
default:
32083263
return false;
32093264
}
3210-
ShouldBeCallInst ^= 1;
3265+
ShouldBeCallOrBinInst ^= 1;
3266+
3267+
IntrinsicCostAttributes ICA(
3268+
*CommonCallOp, II->getType(),
3269+
{PrevVecV[0]->getType(), PrevVecV[1]->getType()});
3270+
OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
32113271

3272+
// We may need a swap here since it can be (a, b) or (b, a)
3273+
// and accordingly change as we go up.
32123274
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
32133275
std::swap(PrevVecV[0], PrevVecV[1]);
32143276
InstWorklist.push(PrevVecV[1]);
32153277
InstWorklist.push(PrevVecV[0]);
3216-
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3217-
if (ShouldBeCallInst ||
3278+
} else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
3279+
// Similar logic for bin ops.
3280+
3281+
if (!ShouldBeCallOrBinInst)
3282+
return false;
3283+
3284+
if (!IsFirstCallOrBinInst &&
32183285
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
32193286
return false;
32203287

3221-
if (SVInst != PrevVecV[1])
3288+
if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
3289+
return false;
3290+
IsFirstCallOrBinInst = false;
3291+
3292+
if (!CommonBinOp)
3293+
CommonBinOp = BinOp->getOpcode();
3294+
3295+
if (BinOp->getOpcode() != *CommonBinOp)
32223296
return false;
32233297

3224-
auto *ShuffleVec = SVInst->getOperand(0);
3225-
if (!ShuffleVec || ShuffleVec != PrevVecV[0])
3298+
switch (*CommonBinOp) {
3299+
case BinaryOperator::Add:
3300+
case BinaryOperator::Mul:
3301+
case BinaryOperator::Or:
3302+
case BinaryOperator::And:
3303+
case BinaryOperator::Xor: {
3304+
auto *Op0 = BinOp->getOperand(0);
3305+
auto *Op1 = BinOp->getOperand(1);
3306+
PrevVecV[0] = Op0;
3307+
PrevVecV[1] = Op1;
3308+
break;
3309+
}
3310+
default:
32263311
return false;
3312+
}
3313+
ShouldBeCallOrBinInst ^= 1;
32273314

3228-
SmallVector<int> CurMask;
3229-
SVInst->getShuffleMask(CurMask);
3315+
OrigCost +=
3316+
TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);
32303317

3231-
if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
3318+
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
3319+
std::swap(PrevVecV[0], PrevVecV[1]);
3320+
InstWorklist.push(PrevVecV[1]);
3321+
InstWorklist.push(PrevVecV[0]);
3322+
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3323+
// We shouldn't have any null values in the previous vectors,
3324+
// is so, there was a mismatch in pattern.
3325+
if (ShouldBeCallOrBinInst ||
3326+
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
3327+
return false;
3328+
3329+
if (SVInst != PrevVecV[1])
3330+
return false;
3331+
3332+
ArrayRef<int> CurMask;
3333+
if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
3334+
m_Mask(CurMask))))
32323335
return false;
3233-
ExpectedShuffleMaskHalf *= 2;
32343336

3337+
// Subtract the parity mask when checking the condition.
32353338
for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
3236-
if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
3339+
if (Mask < ShuffleMaskHalf &&
3340+
CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
32373341
return false;
32383342
if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
32393343
return false;
32403344
}
3345+
3346+
// Update mask values.
32413347
ShuffleMaskHalf *= 2;
3242-
if (ExpectedShuffleMaskHalf == VecSize)
3348+
ShuffleMaskHalf -= (ExpectedParityMask & 1);
3349+
ExpectedParityMask >>= 1;
3350+
3351+
OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3352+
SVInst->getType(), SVInst->getType(),
3353+
CurMask, CostKind);
3354+
3355+
VisitedCnt += 1;
3356+
if (!ExpectedParityMask && VisitedCnt == NumLevels)
32433357
break;
3244-
ShouldBeCallInst ^= 1;
3358+
3359+
ShouldBeCallOrBinInst ^= 1;
32453360
} else {
32463361
return false;
32473362
}
32483363
}
32493364

3250-
if (ShouldBeCallInst)
3365+
// Pattern should end with a shuffle op.
3366+
if (ShouldBeCallOrBinInst)
32513367
return false;
32523368

3253-
assert(VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
3254-
"Expected Match for Vector Size and Mask Half");
3369+
assert(VecSize != -1 && "Expected Match for Vector Size");
32553370

32563371
Value *FinalVecV = PrevVecV[0];
3257-
auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType());
3258-
3259-
if (!InitEEV || !FinalVecV)
3372+
if (!FinalVecV)
32603373
return false;
32613374

3262-
assert(FinalVecVTy && "Expected non-null value for Vector Type");
3375+
auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());
32633376

3264-
Intrinsic::ID ReducedOp = 0;
3265-
switch (CommonOp) {
3266-
case Intrinsic::umin:
3267-
ReducedOp = Intrinsic::vector_reduce_umin;
3268-
break;
3269-
case Intrinsic::umax:
3270-
ReducedOp = Intrinsic::vector_reduce_umax;
3271-
break;
3272-
case Intrinsic::smin:
3273-
ReducedOp = Intrinsic::vector_reduce_smin;
3274-
break;
3275-
case Intrinsic::smax:
3276-
ReducedOp = Intrinsic::vector_reduce_smax;
3277-
break;
3278-
default:
3377+
Intrinsic::ID ReducedOp =
3378+
(CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
3379+
: getReductionForBinop(*CommonBinOp));
3380+
if (!ReducedOp)
32793381
return false;
3280-
}
3281-
3282-
InstructionCost OrigCost = 0;
3283-
unsigned int NumLevels = Log2_64(VecSize);
3284-
3285-
for (unsigned int Level = 0; Level < NumLevels; ++Level) {
3286-
OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3287-
FinalVecVTy, FinalVecVTy);
3288-
OrigCost += TTI.getArithmeticInstrCost(Instruction::ICmp, FinalVecVTy);
3289-
}
3290-
OrigCost += TTI.getVectorInstrCost(Instruction::ExtractElement, FinalVecVTy,
3291-
CostKind, 0);
32923382

32933383
IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
32943384
InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
@@ -3298,7 +3388,7 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
32983388

32993389
auto *ReducedResult =
33003390
Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
3301-
replaceValue(*InitEEV, *ReducedResult);
3391+
replaceValue(I, *ReducedResult);
33023392

33033393
return true;
33043394
}
@@ -4391,8 +4481,8 @@ bool VectorCombine::run() {
43914481
return true;
43924482
break;
43934483
case Instruction::ExtractElement:
4394-
MadeChange |= foldShuffleChainsToReduce(I);
4395-
break;
4484+
if (foldShuffleChainsToReduce(I))
4485+
return true;
43964486
case Instruction::ICmp:
43974487
case Instruction::FCmp:
43984488
if (foldExtractExtract(I))

0 commit comments

Comments
 (0)