diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 4b661ad40f2d4..dd87d34d1f01a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,6 +1371,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns reduction bitwidth and signedness, if it does not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth ==
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
@@ -17887,24 +17899,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        TypeSize NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          APInt Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
@@ -19760,8 +19785,8 @@ class HorizontalReduction {
 
         // Estimate cost.
         InstructionCost TreeCost = V.getTreeCost(VL);
-        InstructionCost ReductionCost =
-            getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+        InstructionCost ReductionCost = getReductionCost(
+            TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
         InstructionCost Cost = TreeCost + ReductionCost;
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                           << " for reduction\n");
@@ -19866,10 +19891,12 @@ class HorizontalReduction {
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
             ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+                ReducedSubTree,
+                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
           }
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                         RdxRootInst->getType());
         }
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20050,12 +20077,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20114,8 +20142,22 @@ class HorizontalReduction {
              VecTy, APInt::getAllOnes(ScalarTyNumElements),
              /*Insert*/ true, /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                     CostKind);
+        auto [Bitwidth, IsSigned] =
+            BitwidthAndSign.value_or(std::make_pair(0u, false));
+        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+          // Represent vector_reduce_add(ZExt(<n x i1>)) to
+          // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+          VectorCost =
+              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                    getWidenedType(ScalarTy, ReduxWidth),
+                                    TTI::CastContextHint::None, CostKind) +
+              TTI->getIntrinsicInstrCost(ICA, CostKind);
+        } else {
+          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                       FMF, CostKind);
+        }
       }
     }
     ScalarCost = EvaluateScalarCost([&]() {
@@ -20152,11 +20194,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
index ecf85159efdfb..f00b846bf4f5b 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
@@ -11,8 +11,9 @@ define i16 @test(i16 %call37) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <8 x i1> [[TMP4]] to i8
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP8]])
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP7]] to i16
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i16 [[TMP6]], 0
 ; CHECK-NEXT:    ret i16 [[OP_RDX]]
 ;
diff --git a/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
index 89fcc7e983749..303e31dfa5e64 100644
--- a/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
@@ -14,8 +14,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <4 x i1> [[TMP5]] to i4
+; CHECK-NEXT:    [[TMP11:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP10]])
+; CHECK-NEXT:    [[TMP7:%.*]] = zext i4 [[TMP11]] to i32
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]
 ; CHECK-NEXT:    ret i32 [[OP_RDX]]
 ;
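
Note (illustration only, not part of the patch): the new emitReduction path above rewrites vector_reduce_add(zext(<N x i1>)) into zext/trunc(ctpop(bitcast <N x i1> to iN)), which is exactly what the updated FileCheck lines show. A minimal standalone sketch of the same IRBuilder sequence follows, assuming a vector-of-i1 mask and a wider integer destination type; the helper name emitPopcountReduction and the parameter names Mask/DestTy are hypothetical and exist only for this example.

#include <cassert>
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

// Hypothetical helper: reduce an <N x i1> mask by counting its set bits.
static Value *emitPopcountReduction(IRBuilderBase &Builder, Value *Mask,
                                    Type *DestTy) {
  auto *MaskTy = cast<FixedVectorType>(Mask->getType());
  assert(MaskTy->getScalarType()->isIntegerTy(1) && "expected an i1 mask");
  // bitcast <N x i1> to iN, then count the set bits.
  Value *Bits = Builder.CreateBitCast(
      Mask, Builder.getIntNTy(MaskTy->getNumElements()));
  Value *Pop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Bits);
  // Zero-extend or truncate the popcount to the requested reduction type.
  return Builder.CreateZExtOrTrunc(Pop, DestTy);
}

The motivation, as reflected in the cost-model change, is that a mask-to-integer bitcast plus a single ctpop is typically cheaper than widening the i1 mask and performing a vector add reduction.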