@@ -2331,21 +2331,21 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23312331// / vector operand are added together and passed to the next iteration as the
23322332// / next accumulator. After the loop body, the accumulator is reduced to a
23332333// / scalar value.
2334- class VPPartialReductionRecipe : public VPSingleDefRecipe {
2334+ class VPPartialReductionRecipe : public VPReductionRecipe {
23352335 unsigned Opcode;
23362336
23372337public:
23382338 VPPartialReductionRecipe (Instruction *ReductionInst, VPValue *Op0,
2339- VPValue *Op1)
2340- : VPPartialReductionRecipe(ReductionInst->getOpcode (), Op0, Op1,
2339+ VPValue *Op1, VPValue *Cond )
2340+ : VPPartialReductionRecipe(ReductionInst->getOpcode (), Op0, Op1, Cond,
23412341 ReductionInst) {}
23422342 VPPartialReductionRecipe (unsigned Opcode, VPValue *Op0, VPValue *Op1,
2343- Instruction *ReductionInst = nullptr )
2344- : VPSingleDefRecipe(VPDef::VPPartialReductionSC,
2345- ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
2343+ VPValue *Cond, Instruction *ReductionInst = nullptr )
2344+ : VPReductionRecipe(VPDef::VPPartialReductionSC, RecurKind::Add,
2345+ FastMathFlags (), ReductionInst,
2346+ ArrayRef<VPValue *>({Op0, Op1}), Cond, false, {}),
23462347 Opcode(Opcode) {
2347- [[maybe_unused]] auto *AccumulatorRecipe =
2348- getOperand (1 )->getDefiningRecipe ();
2348+ [[maybe_unused]] auto *AccumulatorRecipe = getChainOp ()->getDefiningRecipe ();
23492349 assert ((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
23502350 isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
23512351 " Unexpected operand order for partial reduction recipe" );
@@ -2354,7 +2354,7 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
23542354
23552355 VPPartialReductionRecipe *clone () override {
23562356 return new VPPartialReductionRecipe (Opcode, getOperand (0 ), getOperand (1 ),
2357- getUnderlyingInstr ());
2357+ getCondOp (), getUnderlyingInstr ());
23582358 }
23592359
23602360 VP_CLASSOF_IMPL (VPDef::VPPartialReductionSC)
@@ -2369,14 +2369,16 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
23692369 // / Get the binary op's opcode.
23702370 unsigned getOpcode () const { return Opcode; }
23712371
2372+ // / Get the binary op this reduction is applied to.
2373+ VPValue *getBinOp () const { return getOperand (1 ); }
2374+
23722375#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
23732376 // / Print the recipe.
23742377 void print (raw_ostream &O, const Twine &Indent,
23752378 VPSlotTracker &SlotTracker) const override ;
23762379#endif
23772380};
23782381
2379-
23802382// / A recipe to represent inloop reduction operations with vector-predication
23812383// / intrinsics, performing a reduction on a vector operand with the explicit
23822384// / vector length (EVL) into a scalar value, and adding the result to a chain.
@@ -2497,6 +2499,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
24972499
24982500 Type *ResultTy;
24992501
2502+ // / If the reduction this is based on is a partial reduction.
2503+ bool IsPartialReduction = false ;
2504+
25002505 // / For cloning VPMulAccumulateReductionRecipe.
25012506 VPMulAccumulateReductionRecipe (VPMulAccumulateReductionRecipe *MulAcc)
25022507 : VPReductionRecipe(
@@ -2506,7 +2511,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25062511 WrapFlagsTy(MulAcc->hasNoUnsignedWrap (), MulAcc->hasNoSignedWrap()),
25072512 MulAcc->getDebugLoc()),
25082513 ExtOp(MulAcc->getExtOpcode ()), IsNonNeg(MulAcc->isNonNeg ()),
2509- ResultTy(MulAcc->getResultType ()) {}
2514+ ResultTy(MulAcc->getResultType ()),
2515+ IsPartialReduction(MulAcc->isPartialReduction ()) {}
25102516
25112517public:
25122518 VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2519,7 +2525,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25192525 WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
25202526 R->getDebugLoc()),
25212527 ExtOp(Ext0->getOpcode ()), IsNonNeg(Ext0->isNonNeg ()),
2522- ResultTy(ResultTy) {
2528+ ResultTy(ResultTy),
2529+ IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
25232530 assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
25242531 Instruction::Add &&
25252532 " The reduction instruction in MulAccumulateteReductionRecipe must "
@@ -2590,6 +2597,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25902597
25912598 // / Return the non negative flag of the ext recipe.
25922599 bool isNonNeg () const { return IsNonNeg; }
2600+
2601+ // / Return if the underlying reduction recipe is a partial reduction.
2602+ bool isPartialReduction () const { return IsPartialReduction; }
25932603};
25942604
25952605// / VPReplicateRecipe replicates a given instruction producing multiple scalar
0 commit comments