@@ -8758,13 +8758,6 @@ bool VPRecipeBuilder::getScaledReductions(
87588758 if (!CM.TheLoop ->contains (RdxExitInstr))
87598759 return false ;
87608760
8761- // TODO: Allow scaling reductions when predicating. The select at
8762- // the end of the loop chooses between the phi value and most recent
8763- // reduction result, both of which have different VFs to the active lane
8764- // mask when scaling.
8765- if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr->getParent ()))
8766- return false ;
8767-
87688761 auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
87698762 if (!Update)
87708763 return false ;
@@ -8926,8 +8919,19 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
89268919 isa<VPPartialReductionRecipe>(BinOpRecipe))
89278920 std::swap (BinOp, Accumulator);
89288921
8929- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8930- Accumulator, Reduction);
8922+ unsigned ReductionOpcode = Reduction->getOpcode ();
8923+ if (CM.blockNeedsPredicationForAnyReason (Reduction->getParent ())) {
8924+ assert ((ReductionOpcode == Instruction::Add ||
8925+ ReductionOpcode == Instruction::Sub) &&
8926+ " Expected an ADD or SUB operation for predicated partial "
8927+ " reductions (because the neutral element in the mask is zero)!" );
8928+ VPValue *Mask = getBlockInMask (Reduction->getParent ());
8929+ VPValue *Zero =
8930+ Plan.getOrAddLiveIn (ConstantInt::get (Reduction->getType (), 0 ));
8931+ BinOp = Builder.createSelect (Mask, BinOp, Zero, Reduction->getDebugLoc ());
8932+ }
8933+ return new VPPartialReductionRecipe (ReductionOpcode, BinOp, Accumulator,
8934+ Reduction);
89318935}
89328936
89338937void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
@@ -9735,7 +9739,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97359739 // beginning of the dedicated latch block.
97369740 auto *OrigExitingVPV = PhiR->getBackedgeValue ();
97379741 auto *NewExitingVPV = PhiR->getBackedgeValue ();
9738- if (!PhiR->isInLoop () && CM.foldTailByMasking ()) {
9742+ // Don't output selects for partial reductions because they have an output
9743+ // with fewer lanes than the VF. So the operands of the select would have
9744+ // different numbers of lanes. Partial reductions mask the input instead.
9745+ if (!PhiR->isInLoop () && CM.foldTailByMasking () &&
9746+ !isa<VPPartialReductionRecipe>(OrigExitingVPV->getDefiningRecipe ())) {
97399747 VPValue *Cond = RecipeBuilder.getBlockInMask (OrigLoop->getHeader ());
97409748 assert (OrigExitingVPV->getDefiningRecipe ()->getParent () != LatchVPBB &&
97419749 " reduction recipe must be defined before latch" );
0 commit comments