@@ -1371,6 +1371,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }

+  /// Returns the reduction bitwidth and signedness, if they do not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth ==
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
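The accessor folds the "no demotion happened" case into std::nullopt, so a caller only sees a width/sign pair when the reduction tree was actually narrowed. A minimal consumption sketch (the `Slp` reference and the surrounding logic are illustrative, not part of the patch; the real call site added below is `V.getReductionBitWidthAndSign()` in HorizontalReduction):

    // Cost or emit the reduction at the demoted width, when there is one.
    if (std::optional<std::pair<unsigned, bool>> BWSign =
            Slp.getReductionBitWidthAndSign()) {
      unsigned BitWidth = BWSign->first; // narrowed width in bits
      bool IsSigned = BWSign->second;    // sign- vs. zero-extend on the way out
      // ... use BitWidth/IsSigned instead of the original scalar type ...
    }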
@@ -17885,24 +17897,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          auto Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;

-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
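In words, the new early exit fires when the whole reduction merely counts the set lanes of a boolean vector: every reduction operation is an integer add and the single vectorized root zero-extends from i1. A standalone restatement of that guard (a sketch only; the parameters are illustrative stand-ins for `*UserIgnoreList` and the root node's main cast instruction):

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // True when the reduction is add(zext(<n x i1>)), i.e. it just counts set
    // bits of a boolean vector, so it can later be lowered as bitcast + ctpop.
    static bool isBooleanAddReduction(ArrayRef<Value *> ReductionOps,
                                      const CastInst *RootCast) {
      return all_of(ReductionOps,
                    [](Value *V) {
                      return cast<Instruction>(V)->getOpcode() ==
                             Instruction::Add;
                    }) &&
             RootCast && RootCast->getOpcode() == Instruction::ZExt &&
             RootCast->getSrcTy()->isIntegerTy(1);
    }

Setting ReductionBitWidth to 1 in this case is what later lets both the cost model and emitReduction recognize the pattern.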
@@ -19758,8 +19783,8 @@ class HorizontalReduction {

     // Estimate cost.
     InstructionCost TreeCost = V.getTreeCost(VL);
-    InstructionCost ReductionCost =
-        getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+    InstructionCost ReductionCost = getReductionCost(
+        TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
     InstructionCost Cost = TreeCost + ReductionCost;
     LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                       << " for reduction\n");
@@ -19864,10 +19889,12 @@ class HorizontalReduction {
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
             ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+                ReducedSubTree,
+                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
           }
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                         RdxRootInst->getType());
         }
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20048,12 +20075,13 @@ class HorizontalReduction {

 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20112,8 +20140,22 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
             /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                     CostKind);
+        auto [Bitwidth, IsSigned] =
+            BitwidthAndSign.value_or(std::make_pair(0u, false));
+        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+          // Convert vector_reduce_add(ZExt(<n x i1>)) to
+          // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+          VectorCost =
+              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                    getWidenedType(ScalarTy, ReduxWidth),
+                                    TTI::CastContextHint::None, CostKind) +
+              TTI->getIntrinsicInstrCost(ICA, CostKind);
+        } else {
+          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                       FMF, CostKind);
+        }
       }
     }
     ScalarCost = EvaluateScalarCost([&]() {
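As a worked example of the new costing branch (illustrative numbers): for an 8-lane boolean add reduction, `ReduxWidth` is 8, so `IntTy` is i8 and the vector cost becomes

    VectorCost = getCastInstrCost(BitCast, i8, <8 x i1>)
               + getIntrinsicInstrCost(ctpop.i8)

instead of the `getArithmeticReductionCost(Add, ...)` query that would otherwise be issued for the widened vector type.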
@@ -20150,11 +20192,22 @@ class HorizontalReduction {

   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");

+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
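Taken together with the final extension the caller already performs when the reduced type differs from the requested one, the new path amounts to the following self-contained sequence (a sketch only, assuming an `<n x i1>` input and a wider integer `DestTy`; the helper name is made up):

    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // Emits zext(ctpop(bitcast <n x i1> %Mask to iN)) to DestTy: the
    // population count of the mask equals the sum of its zero-extended lanes.
    static Value *emitBoolAddReduction(IRBuilderBase &Builder, Value *Mask,
                                       Type *DestTy) {
      auto *FTy = cast<FixedVectorType>(Mask->getType());
      // One bit per lane: <n x i1> -> iN.
      Value *Bits = Builder.CreateBitCast(
          Mask, Builder.getIntNTy(FTy->getNumElements()));
      // Count the set bits of the packed mask.
      Value *Count = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Bits);
      // Widen (or truncate) to the type the original reduction produced.
      return Builder.CreateZExtOrTrunc(Count, DestTy);
    }

This replaces a vector add reduction with one bitcast and one scalar ctpop, which is typically cheaper on targets with a native population-count instruction.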