@@ -22144,53 +22144,16 @@ class HorizontalReduction {
2214422144 }
2214522145
2214622146 Type *ScalarTy = VL.front()->getType();
22147- if (isa<FixedVectorType>(ScalarTy)) {
22148- assert(SLPReVec && "FixedVectorType is not expected.");
22149- unsigned ScalarTyNumElements = getNumElements(ScalarTy);
22150- Value *ReducedSubTree = PoisonValue::get(
22151- getWidenedType(ScalarTy->getScalarType(), ScalarTyNumElements));
22152- for (unsigned I : seq<unsigned>(ScalarTyNumElements)) {
22153- // Do reduction for each lane.
22154- // e.g., do reduce add for
22155- // VL[0] = <4 x Ty> <a, b, c, d>
22156- // VL[1] = <4 x Ty> <e, f, g, h>
22157- // Lane[0] = <2 x Ty> <a, e>
22158- // Lane[1] = <2 x Ty> <b, f>
22159- // Lane[2] = <2 x Ty> <c, g>
22160- // Lane[3] = <2 x Ty> <d, h>
22161- // result[0] = reduce add Lane[0]
22162- // result[1] = reduce add Lane[1]
22163- // result[2] = reduce add Lane[2]
22164- // result[3] = reduce add Lane[3]
22165- SmallVector<int, 16> Mask =
22166- createStrideMask(I, ScalarTyNumElements, VL.size());
22167- Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
22168- Value *Val =
22169- createSingleOp(Builder, *TTI, Lane,
22170- OptReusedScalars && SameScaleFactor
22171- ? SameValuesCounter.front().second
22172- : 1,
22173- Lane->getType()->getScalarType() !=
22174- VL.front()->getType()->getScalarType()
22175- ? V.isSignedMinBitwidthRootNode()
22176- : true,
22177- RdxRootInst->getType());
22178- ReducedSubTree =
22179- Builder.CreateInsertElement(ReducedSubTree, Val, I);
22180- }
22181- VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree);
22182- } else {
22183- Type *VecTy = VectorizedRoot->getType();
22184- Type *RedScalarTy = VecTy->getScalarType();
22185- VectorValuesAndScales.emplace_back(
22186- VectorizedRoot,
22187- OptReusedScalars && SameScaleFactor
22188- ? SameValuesCounter.front().second
22189- : 1,
22190- RedScalarTy != ScalarTy->getScalarType()
22191- ? V.isSignedMinBitwidthRootNode()
22192- : true);
22193- }
22147+ Type *VecTy = VectorizedRoot->getType();
22148+ Type *RedScalarTy = VecTy->getScalarType();
22149+ VectorValuesAndScales.emplace_back(
22150+ VectorizedRoot,
22151+ OptReusedScalars && SameScaleFactor
22152+ ? SameValuesCounter.front().second
22153+ : 1,
22154+ RedScalarTy != ScalarTy->getScalarType()
22155+ ? V.isSignedMinBitwidthRootNode()
22156+ : true);
2219422157
2219522158 // Count vectorized reduced values to exclude them from final reduction.
2219622159 for (Value *RdxVal : VL) {
@@ -22363,9 +22326,35 @@ class HorizontalReduction {
2236322326 Value *createSingleOp(IRBuilderBase &Builder, const TargetTransformInfo &TTI,
2236422327 Value *Vec, unsigned Scale, bool IsSigned,
2236522328 Type *DestTy) {
22366- Value *Rdx = emitReduction(Vec, Builder, &TTI, DestTy);
22367- if (Rdx->getType() != DestTy->getScalarType())
22368- Rdx = Builder.CreateIntCast(Rdx, DestTy->getScalarType(), IsSigned);
22329+ Value *Rdx;
22330+ if (auto *VecTy = dyn_cast<FixedVectorType>(DestTy)) {
22331+ unsigned DestTyNumElements = getNumElements(VecTy);
22332+ unsigned VF = getNumElements(Vec->getType()) / DestTyNumElements;
22333+ Rdx = PoisonValue::get(
22334+ getWidenedType(Vec->getType()->getScalarType(), DestTyNumElements));
22335+ for (unsigned I : seq<unsigned>(DestTyNumElements)) {
22336+ // Do reduction for each lane.
22337+ // e.g., do reduce add for
22338+ // VL[0] = <4 x Ty> <a, b, c, d>
22339+ // VL[1] = <4 x Ty> <e, f, g, h>
22340+ // Lane[0] = <2 x Ty> <a, e>
22341+ // Lane[1] = <2 x Ty> <b, f>
22342+ // Lane[2] = <2 x Ty> <c, g>
22343+ // Lane[3] = <2 x Ty> <d, h>
22344+ // result[0] = reduce add Lane[0]
22345+ // result[1] = reduce add Lane[1]
22346+ // result[2] = reduce add Lane[2]
22347+ // result[3] = reduce add Lane[3]
22348+ SmallVector<int, 16> Mask = createStrideMask(I, DestTyNumElements, VF);
22349+ Value *Lane = Builder.CreateShuffleVector(Vec, Mask);
22350+ Rdx = Builder.CreateInsertElement(
22351+ Rdx, emitReduction(Lane, Builder, &TTI, DestTy), I);
22352+ }
22353+ } else {
22354+ Rdx = emitReduction(Vec, Builder, &TTI, DestTy);
22355+ }
22356+ if (Rdx->getType() != DestTy)
22357+ Rdx = Builder.CreateIntCast(Rdx, DestTy, IsSigned);
2236922358 // Improved analysis for add/fadd/xor reductions with same scale
2237022359 // factor for all operands of reductions. We can emit scalar ops for
2237122360 // them instead.
@@ -22432,30 +22421,32 @@ class HorizontalReduction {
2243222421 case RecurKind::FMul: {
2243322422 unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
2243422423 if (!AllConsts) {
22435- if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
22436- assert(SLPReVec && "FixedVectorType is not expected.");
22437- unsigned ScalarTyNumElements = VecTy->getNumElements();
22438- for (unsigned I : seq<unsigned>(ReducedVals.size())) {
22439- VectorCost += TTI->getShuffleCost(
22440- TTI::SK_PermuteSingleSrc, VectorTy,
22441- createStrideMask(I, ScalarTyNumElements, ReducedVals.size()));
22442- VectorCost += TTI->getArithmeticReductionCost(RdxOpcode, VecTy, FMF,
22443- CostKind);
22444- }
22445- VectorCost += TTI->getScalarizationOverhead(
22446- VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
22447- /*Extract*/ false, TTI::TCK_RecipThroughput);
22448- } else if (DoesRequireReductionOp) {
22449- Type *RedTy = VectorTy->getElementType();
22450- auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
22451- std::make_pair(RedTy, true));
22452- if (RType == RedTy) {
22453- VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
22454- FMF, CostKind);
22424+ if (DoesRequireReductionOp) {
22425+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
22426+ assert(SLPReVec && "FixedVectorType is not expected.");
22427+ unsigned ScalarTyNumElements = VecTy->getNumElements();
22428+ for (unsigned I : seq<unsigned>(ReducedVals.size())) {
22429+ VectorCost += TTI->getShuffleCost(
22430+ TTI::SK_PermuteSingleSrc, VectorTy,
22431+ createStrideMask(I, ScalarTyNumElements, ReducedVals.size()));
22432+ VectorCost += TTI->getArithmeticReductionCost(RdxOpcode, VecTy,
22433+ FMF, CostKind);
22434+ }
22435+ VectorCost += TTI->getScalarizationOverhead(
22436+ VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
22437+ /*Extract*/ false, TTI::TCK_RecipThroughput);
2245522438 } else {
22456- VectorCost = TTI->getExtendedReductionCost(
22457- RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
22458- FMF, CostKind);
22439+ Type *RedTy = VectorTy->getElementType();
22440+ auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
22441+ std::make_pair(RedTy, true));
22442+ if (RType == RedTy) {
22443+ VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
22444+ FMF, CostKind);
22445+ } else {
22446+ VectorCost = TTI->getExtendedReductionCost(
22447+ RdxOpcode, !IsSigned, RedTy,
22448+ getWidenedType(RType, ReduxWidth), FMF, CostKind);
22449+ }
2245922450 }
2246022451 } else {
2246122452 Type *RedTy = VectorTy->getElementType();
0 commit comments