Skip to content

Commit 517d725

Browse files
[LV] Move condition to VPPartialReductionRecipe::execute (#166136)
This means that VPExpressions will now be constructed for VPPartialReductionRecipes when the loop has tail-folding predication. Note that control-flow (if/else) predication is not yet handled for partial reductions because of the way partial reductions are recognised and built up.
1 parent 8e3188a commit 517d725

File tree

6 files changed

+424
-664
lines changed

6 files changed

+424
-664
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8226,15 +8226,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(VPInstruction *Reduction,
82268226
}
82278227

82288228
VPValue *Cond = nullptr;
8229-
if (CM.blockNeedsPredicationForAnyReason(ReductionI->getParent())) {
8230-
assert((ReductionOpcode == Instruction::Add ||
8231-
ReductionOpcode == Instruction::Sub) &&
8232-
"Expected an ADD or SUB operation for predicated partial "
8233-
"reductions (because the neutral element in the mask is zero)!");
8229+
if (CM.blockNeedsPredicationForAnyReason(ReductionI->getParent()))
82348230
Cond = getBlockInMask(Builder.getInsertBlock());
8235-
VPValue *Zero = Plan.getConstantInt(ReductionI->getType(), 0);
8236-
BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
8237-
}
82388231
return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
82398232
ScaleFactor, ReductionI);
82408233
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -311,18 +311,27 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
311311
std::optional<unsigned> Opcode;
312312
VPValue *Op = getVecOp();
313313
uint64_t MulConst;
314+
315+
InstructionCost CondCost = 0;
316+
if (isConditional()) {
317+
CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
318+
auto *VecTy = Ctx.Types.inferScalarType(Op);
319+
auto *CondTy = Ctx.Types.inferScalarType(getCondOp());
320+
CondCost = Ctx.TTI.getCmpSelInstrCost(Instruction::Select, VecTy, CondTy,
321+
Pred, Ctx.CostKind);
322+
}
323+
314324
// If the partial reduction is predicated, a select will be operand 1.
315325
// If it isn't predicated and the mul isn't operating on a constant, then it
316326
// should have been turned into a VPExpressionRecipe.
317327
// FIXME: Replace the entire function with this once all partial reduction
318328
// variants are bundled into VPExpressionRecipe.
319-
if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) &&
320-
!match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) {
329+
if (!match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) {
321330
auto *PhiType = Ctx.Types.inferScalarType(getChainOp());
322331
auto *InputType = Ctx.Types.inferScalarType(getVecOp());
323-
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType,
324-
PhiType, VF, TTI::PR_None,
325-
TTI::PR_None, {}, Ctx.CostKind);
332+
return CondCost + Ctx.TTI.getPartialReductionCost(
333+
getOpcode(), InputType, InputType, PhiType, VF,
334+
TTI::PR_None, TTI::PR_None, {}, Ctx.CostKind);
326335
}
327336

328337
VPRecipeBase *OpR = Op->getDefiningRecipe();
@@ -381,12 +390,13 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
381390
} else if (auto Widen = dyn_cast<VPWidenRecipe>(OpR)) {
382391
HandleWiden(Widen);
383392
} else if (auto Reduction = dyn_cast<VPPartialReductionRecipe>(OpR)) {
384-
return Reduction->computeCost(VF, Ctx);
393+
return CondCost + Reduction->computeCost(VF, Ctx);
385394
}
386395
auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
387-
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
388-
PhiType, VF, ExtAType, ExtBType,
389-
Opcode, Ctx.CostKind);
396+
return CondCost + Ctx.TTI.getPartialReductionCost(
397+
getOpcode(), InputTypeA, InputTypeB, PhiType, VF,
398+
ExtAType, ExtBType, Opcode, Ctx.CostKind);
399+
;
390400
}
391401

392402
void VPPartialReductionRecipe::execute(VPTransformState &State) {
@@ -395,12 +405,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
395405
assert(getOpcode() == Instruction::Add &&
396406
"Unhandled partial reduction opcode");
397407

398-
Value *BinOpVal = State.get(getOperand(1));
399-
Value *PhiVal = State.get(getOperand(0));
408+
Value *BinOpVal = State.get(getVecOp());
409+
Value *PhiVal = State.get(getChainOp());
400410
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
401411

402412
Type *RetTy = PhiVal->getType();
403413

414+
if (isConditional()) {
415+
Value *Cond = State.get(getCondOp());
416+
Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
417+
BinOpVal = Builder.CreateSelect(Cond, BinOpVal, Zero);
418+
}
419+
404420
CallInst *V =
405421
Builder.CreateIntrinsic(RetTy, Intrinsic::vector_partial_reduce_add,
406422
{PhiVal, BinOpVal}, nullptr, "partial.reduce");

0 commit comments

Comments
 (0)