@@ -1377,6 +1377,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns reduction bitwidth and signedness, if it does not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth >=
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
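
Aside: the new accessor returns `std::nullopt` when the reduction was not demoted below its original scalar width, so callers can unpack it with a neutral default. A minimal self-contained sketch of that consumption pattern (the demo function and its constants are hypothetical stand-ins, not the LLVM API):

```cpp
#include <optional>
#include <utility>

// Hypothetical stand-in for BoUpSLP::getReductionBitWidthAndSign(): nullopt
// means the reduction keeps its original requested size.
std::optional<std::pair<unsigned, bool>> getWidthAndSignDemo() {
  unsigned ReductionBitWidth = 1; // e.g. demoted to i1
  unsigned ScalarTypeBits = 32;   // original reduction type, e.g. i32
  if (ReductionBitWidth == 0 || ReductionBitWidth >= ScalarTypeBits)
    return std::nullopt;
  return std::make_pair(ReductionBitWidth, /*IsSigned=*/false);
}

int main() {
  // Unpack with a neutral default, mirroring how getReductionCost below
  // consumes the optional via value_or.
  auto [Bitwidth, IsSigned] =
      getWidthAndSignDemo().value_or(std::make_pair(0u, false));
  return Bitwidth == 1 && !IsSigned ? 0 : 1;
}
```
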
@@ -17916,24 +17928,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to ZExtOrTrunc(ctpop(bitcast
+    // <n x i1> to iN)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        TypeSize NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          APInt Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
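
For the non-i1 path, each reduction op contributes two width bounds: one from known sign bits (plus one bit if the value may be negative) and one from the highest demanded bit; the loop keeps the smaller of the two, maximized over all ops, then widens sub-byte results to 8 bits and rounds up to a power of two. A toy model of that arithmetic in plain C++20 (not the LLVM APIs; a 64-bit mask stands in for `APInt`):

```cpp
#include <algorithm>
#include <bit>
#include <cstdint>

// Toy model of the per-value bound: e.g. a 32-bit value with 24 known sign
// bits that may be negative gives BitWidth1 = 32 - 24 + 1 = 9; if only the
// low 4 bits are demanded, BitWidth2 = 4; min(9, 4) = 4 is then widened to 8
// and bit_ceil keeps the result a power of two.
unsigned minReductionWidth(unsigned TypeBits, unsigned SignBits,
                           bool KnownNonNegative, uint64_t DemandedBits,
                           bool IsMinMaxRecurrence) {
  unsigned BitWidth1 = TypeBits - SignBits + (KnownNonNegative ? 0 : 1);
  unsigned BitWidth2 = BitWidth1;
  if (!IsMinMaxRecurrence) // demanded bits only consulted for non-min/max
    BitWidth2 = 64 - std::countl_zero(DemandedBits);
  unsigned BW = std::min(BitWidth1, BitWidth2);
  if (BW > 1 && BW < 8)
    BW = 8;
  return std::bit_ceil(BW);
}
```
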
@@ -19789,8 +19814,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+      InstructionCost ReductionCost = getReductionCost(
+          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -19895,10 +19920,12 @@ class HorizontalReduction {
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
             ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+                ReducedSubTree,
+                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
           }
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                         RdxRootInst->getType());
         }
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20079,12 +20106,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20143,8 +20171,22 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
             /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                     CostKind);
+        auto [Bitwidth, IsSigned] =
+            BitwidthAndSign.value_or(std::make_pair(0u, false));
+        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+          // Convert vector_reduce_add(ZExt(<n x i1>)) to
+          // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+          VectorCost =
+              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                    getWidenedType(ScalarTy, ReduxWidth),
+                                    TTI::CastContextHint::None, CostKind) +
+              TTI->getIntrinsicInstrCost(ICA, CostKind);
+        } else {
+          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                       FMF, CostKind);
+        }
       }
     }
     ScalarCost = EvaluateScalarCost([&]() {
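
The special case prices the rewritten form as one mask bitcast plus one ctpop instead of a generic add reduction. A sketch with purely illustrative numbers (not real TTI costs for any target) showing the shape of the comparison:

```cpp
// Toy cost model: on targets with cheap mask extraction (e.g. a movmsk-style
// instruction) and a native popcount, the rewritten form beats the
// ~log2(lanes) shuffle+add steps of a generic add-reduction tree.
struct ToyTTI {
  unsigned BitcastI1VecToInt = 1; // hypothetical cost of bitcast <n x i1> -> iN
  unsigned CtpopIntN = 1;         // hypothetical cost of ctpop iN
  unsigned AddReduction = 4;      // hypothetical cost of the generic reduction
};

unsigned vectorReduceAddCost(const ToyTTI &TTI, bool IsZExtOfI1) {
  if (IsZExtOfI1)
    return TTI.BitcastI1VecToInt + TTI.CtpopIntN; // 2 in this toy model
  return TTI.AddReduction;                        // 4 in this toy model
}
```
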
@@ -20181,11 +20223,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
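
Why the emitted sequence is sound: adding n zero-extended i1 lanes counts the set lanes, which is exactly a population count of the lanes packed into an n-bit integer. A scalar model in plain C++20 (assuming n <= 64 here; the emitted IR handles arbitrary widths via iN):

```cpp
#include <bit>
#include <cstdint>

// Model of the original form: ZExt each i1 lane, then add them all.
uint32_t reduceAddZExt(const bool *Lanes, unsigned N) {
  uint32_t Sum = 0;
  for (unsigned I = 0; I < N; ++I)
    Sum += Lanes[I]; // ZExt(i1 lane) + add
  return Sum;
}

// Model of the rewritten form: pack lane i into bit i (the bitcast
// <n x i1> -> iN), then take the population count (the ctpop).
uint32_t ctpopOfBitcast(const bool *Lanes, unsigned N) {
  uint64_t Bits = 0;
  for (unsigned I = 0; I < N; ++I)
    Bits |= uint64_t(Lanes[I]) << I;
  return std::popcount(Bits);
}
```
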