Commit 75a6cd4

[𝘀𝗽𝗿] initial version
Created using spr 1.3.5
1 parent 79682c4 commit 75a6cd4

3 files changed, +86 -31 lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 80 additions & 27 deletions
@@ -1371,6 +1371,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }

+  /// Returns reduction bitwidth and signedness, if it does not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth ==
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
@@ -17885,24 +17897,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          auto Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;

-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
@@ -19758,8 +19783,8 @@ class HorizontalReduction {

       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+      InstructionCost ReductionCost = getReductionCost(
+          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -19864,10 +19889,12 @@ class HorizontalReduction {
              createStrideMask(I, ScalarTyNumElements, VL.size());
          Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
          ReducedSubTree = Builder.CreateInsertElement(
-             ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+             ReducedSubTree,
+             emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
        }
      } else {
-       ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+       ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                      RdxRootInst->getType());
      }
      if (ReducedSubTree->getType() != VL.front()->getType()) {
        assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20048,12 +20075,13 @@ class HorizontalReduction {

 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20112,8 +20140,22 @@ class HorizontalReduction {
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
         } else {
-          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                       CostKind);
+          auto [Bitwidth, IsSigned] =
+              BitwidthAndSign.value_or(std::make_pair(0u, false));
+          if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+            // Represent vector_reduce_add(ZExt(<n x i1>)) to
+            // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+            auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+            IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+            VectorCost =
+                TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                      getWidenedType(ScalarTy, ReduxWidth),
+                                      TTI::CastContextHint::None, CostKind) +
+                TTI->getIntrinsicInstrCost(ICA, CostKind);
+          } else {
+            VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                         FMF, CostKind);
+          }
         }
       }
       ScalarCost = EvaluateScalarCost([&]() {
@@ -20150,11 +20192,22 @@ class HorizontalReduction {

   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");

+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
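
Note: the comments in the patch above describe the same rewrite that the updated tests below check for: an add-reduction over a zero-extended i1 mask becomes a popcount of the mask bits. A minimal standalone IR sketch of the before/after shapes, modeled on the updated test expectations (function names and %mask are illustrative, not part of this commit):

declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
declare i8 @llvm.ctpop.i8(i8)

; Before: widen each i1 lane to i16, then add-reduce the vector.
define i16 @reduce_mask_before(<8 x i1> %mask) {
  %wide = zext <8 x i1> %mask to <8 x i16>
  %sum = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %wide)
  ret i16 %sum
}

; After: bitcast the 8-lane mask to i8, count the set bits, then widen the result.
define i16 @reduce_mask_after(<8 x i1> %mask) {
  %bits = bitcast <8 x i1> %mask to i8
  %pop = call i8 @llvm.ctpop.i8(i8 %bits)
  %sum = zext i8 %pop to i16
  ret i16 %sum
}

When the number of reduced lanes is not a multiple of 8, the same rewrite produces a non-byte-sized integer, as in the i4 ctpop of the second test below.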

llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,9 @@ define i16 @test(i16 %call37) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <8 x i1> [[TMP4]] to i8
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP8]])
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP7]] to i16
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i16 [[TMP6]], 0
 ; CHECK-NEXT:    ret i16 [[OP_RDX]]
 ;

llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll

Lines changed: 3 additions & 2 deletions
@@ -14,8 +14,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <4 x i1> [[TMP5]] to i4
+; CHECK-NEXT:    [[TMP11:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP10]])
+; CHECK-NEXT:    [[TMP7:%.*]] = zext i4 [[TMP11]] to i32
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]
 ; CHECK-NEXT:    ret i32 [[OP_RDX]]
 ;
