@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
525525 case VPRecipeBase::VPInstructionSC:
526526 case VPRecipeBase::VPReductionEVLSC:
527527 case VPRecipeBase::VPReductionSC:
528+ case VPRecipeBase::VPMulAccumulateReductionSC:
529+ case VPRecipeBase::VPExtendedReductionSC:
528530 case VPRecipeBase::VPReplicateSC:
529531 case VPRecipeBase::VPScalarIVStepsSC:
530532 case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
609611 DisjointFlagsTy (bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
610612 };
611613
/// Flag storage for recipes whose operation type is NonNegOp: a single bit
/// recording whether the non-negative flag is set.
struct NonNegFlagsTy {
  char NonNeg : 1;

  /// Implicit on purpose, so a plain bool can be passed where the flags are
  /// expected.
  NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
};
618+
612619private:
613620 struct ExactFlagsTy {
614621 char IsExact : 1 ;
615622 };
616- struct NonNegFlagsTy {
617- char NonNeg : 1 ;
618- };
619623 struct FastMathFlagsTy {
620624 char AllowReassoc : 1 ;
621625 char NoNaNs : 1 ;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
709713 : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
710714 DisjointFlags(DisjointFlags) {}
711715
716+ template <typename IterT>
717+ VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
718+ NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
719+ : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
720+ NonNegFlags(NonNegFlags) {}
721+
712722protected:
713723 template <typename IterT>
714724 VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
728738 R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
729739 R->getVPDefID () == VPRecipeBase::VPReplicateSC ||
730740 R->getVPDefID () == VPRecipeBase::VPVectorEndPointerSC ||
731- R->getVPDefID () == VPRecipeBase::VPVectorPointerSC;
741+ R->getVPDefID () == VPRecipeBase::VPVectorPointerSC ||
742+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
743+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
732744 }
733745
734746 static inline bool classof (const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
820832
821833 FastMathFlags getFastMathFlags () const ;
822834
835+ // / Returns true if the recipe has non-negative flag.
836+ bool hasNonNegFlag () const { return OpType == OperationType::NonNegOp; }
837+
838+ bool isNonNeg () const {
839+ assert (OpType == OperationType::NonNegOp &&
840+ " recipe doesn't have a NNEG flag" );
841+ return NonNegFlags.NonNeg ;
842+ }
843+
823844 bool hasNoUnsignedWrap () const {
824845 assert (OpType == OperationType::OverflowingBinOp &&
825846 " recipe doesn't have a NUW flag" );
@@ -2373,6 +2394,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23732394 setUnderlyingValue (I);
23742395 }
23752396
2397+ // / For VPExtendedReductionRecipe.
2398+ // / Note that the debug location is from the extend.
2399+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2400+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2401+ bool IsOrdered, DebugLoc DL)
2402+ : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2403+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2404+ if (CondOp)
2405+ addOperand (CondOp);
2406+ }
2407+
2408+ // / For VPMulAccumulateReductionRecipe.
2409+ // / Note that the NUW/NSW flags and the debug location are from the Mul.
2410+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2411+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2412+ bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2413+ : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2414+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2415+ if (CondOp)
2416+ addOperand (CondOp);
2417+ }
2418+
23762419public:
23772420 VPReductionRecipe (RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23782421 VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2424,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23812424 ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23822425 IsOrdered, DL) {}
23832426
2427+ VPReductionRecipe (const RecurKind RdxKind, FastMathFlags FMFs,
2428+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2429+ bool IsOrdered, DebugLoc DL = {})
2430+ : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr ,
2431+ ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2432+ IsOrdered, DL) {}
2433+
23842434 ~VPReductionRecipe () override = default ;
23852435
23862436 VPReductionRecipe *clone () override {
@@ -2391,7 +2441,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23912441
23922442 static inline bool classof (const VPRecipeBase *R) {
23932443 return R->getVPDefID () == VPRecipeBase::VPReductionSC ||
2394- R->getVPDefID () == VPRecipeBase::VPReductionEVLSC;
2444+ R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
2445+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
2446+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
23952447 }
23962448
23972449 static inline bool classof (const VPUser *U) {
@@ -2471,6 +2523,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
24712523 }
24722524};
24732525
2526+ // / A recipe to represent inloop extended reduction operations, performing a
2527+ // / reduction on a extended vector operand into a scalar value, and adding the
2528+ // / result to a chain. This recipe is abstract and needs to be lowered to
2529+ // / concrete recipes before codegen. The operands are {ChainOp, VecOp,
2530+ // / [Condition]}.
2531+ class VPExtendedReductionRecipe : public VPReductionRecipe {
2532+ // / Opcode of the extend recipe will be lowered to.
2533+ Instruction::CastOps ExtOp;
2534+
2535+ Type *ResultTy;
2536+
2537+ // / For cloning VPExtendedReductionRecipe.
2538+ VPExtendedReductionRecipe (VPExtendedReductionRecipe *ExtRed)
2539+ : VPReductionRecipe(
2540+ VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind (),
2541+ {ExtRed->getChainOp (), ExtRed->getVecOp ()}, ExtRed->getCondOp (),
2542+ ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2543+ ExtOp(ExtRed->getExtOpcode ()), ResultTy(ExtRed->getResultType ()) {
2544+ transferFlags (*ExtRed);
2545+ }
2546+
2547+ public:
2548+ VPExtendedReductionRecipe (VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2549+ : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind (),
2550+ {R->getChainOp (), Ext->getOperand (0 )}, R->getCondOp (),
2551+ R->isOrdered(), Ext->getDebugLoc()),
2552+ ExtOp(Ext->getOpcode ()), ResultTy(Ext->getResultType ()) {
2553+ // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2554+ // the original recipe to prevent setting wrong flags.
2555+ transferFlags (*Ext);
2556+ }
2557+
2558+ ~VPExtendedReductionRecipe () override = default ;
2559+
2560+ VPExtendedReductionRecipe *clone () override {
2561+ auto *Copy = new VPExtendedReductionRecipe (this );
2562+ Copy->transferFlags (*this );
2563+ return Copy;
2564+ }
2565+
2566+ VP_CLASSOF_IMPL (VPDef::VPExtendedReductionSC);
2567+
2568+ void execute (VPTransformState &State) override {
2569+ llvm_unreachable (" VPExtendedReductionRecipe should be transform to "
2570+ " VPExtendedRecipe + VPReductionRecipe before execution." );
2571+ };
2572+
2573+ // / Return the cost of VPExtendedReductionRecipe.
2574+ InstructionCost computeCost (ElementCount VF,
2575+ VPCostContext &Ctx) const override ;
2576+
2577+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2578+ // / Print the recipe.
2579+ void print (raw_ostream &O, const Twine &Indent,
2580+ VPSlotTracker &SlotTracker) const override ;
2581+ #endif
2582+
2583+ // / The scalar type after extending.
2584+ Type *getResultType () const { return ResultTy; }
2585+
2586+ // / Is the extend ZExt?
2587+ bool isZExt () const { return getExtOpcode () == Instruction::ZExt; }
2588+
2589+ // / The opcode of extend recipe.
2590+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2591+ };
2592+
2593+ // / A recipe to represent inloop MulAccumulateReduction operations, performing a
2594+ // / reduction.add on the result of vector operands (might be extended)
2595+ // / multiplication into a scalar value, and adding the result to a chain. This
2596+ // / recipe is abstract and needs to be lowered to concrete recipes before
2597+ // / codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2598+ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2599+ // / Opcode of the extend recipe.
2600+ Instruction::CastOps ExtOp;
2601+
2602+ // / Non-neg flag of the extend recipe.
2603+ bool IsNonNeg = false ;
2604+
2605+ Type *ResultTy;
2606+
2607+ // / For cloning VPMulAccumulateReductionRecipe.
2608+ VPMulAccumulateReductionRecipe (VPMulAccumulateReductionRecipe *MulAcc)
2609+ : VPReductionRecipe(
2610+ VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind (),
2611+ {MulAcc->getChainOp (), MulAcc->getVecOp0 (), MulAcc->getVecOp1 ()},
2612+ MulAcc->getCondOp (), MulAcc->isOrdered(),
2613+ WrapFlagsTy(MulAcc->hasNoUnsignedWrap (), MulAcc->hasNoSignedWrap()),
2614+ MulAcc->getDebugLoc()),
2615+ ExtOp(MulAcc->getExtOpcode ()), IsNonNeg(MulAcc->isNonNeg ()),
2616+ ResultTy(MulAcc->getResultType ()) {}
2617+
2618+ public:
2619+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul,
2620+ VPWidenCastRecipe *Ext0,
2621+ VPWidenCastRecipe *Ext1, Type *ResultTy)
2622+ : VPReductionRecipe(
2623+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2624+ {R->getChainOp (), Ext0->getOperand (0 ), Ext1->getOperand (0 )},
2625+ R->getCondOp (), R->isOrdered(),
2626+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2627+ R->getDebugLoc()),
2628+ ExtOp(Ext0->getOpcode ()), ResultTy(ResultTy) {
2629+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2630+ Instruction::Add &&
2631+ " The reduction instruction in MulAccumulateteReductionRecipe must "
2632+ " be Add" );
2633+ // Only set the non-negative flag if the original recipe contains.
2634+ if (Ext0->hasNonNegFlag ())
2635+ IsNonNeg = Ext0->isNonNeg ();
2636+ }
2637+
2638+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul)
2639+ : VPReductionRecipe(
2640+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2641+ {R->getChainOp (), Mul->getOperand (0 ), Mul->getOperand (1 )},
2642+ R->getCondOp (), R->isOrdered(),
2643+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2644+ R->getDebugLoc()),
2645+ ExtOp(Instruction::CastOps::CastOpsEnd) {
2646+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2647+ Instruction::Add &&
2648+ " The reduction instruction in MulAccumulateReductionRecipe must be "
2649+ " Add" );
2650+ }
2651+
2652+ ~VPMulAccumulateReductionRecipe () override = default ;
2653+
2654+ VPMulAccumulateReductionRecipe *clone () override {
2655+ auto *Copy = new VPMulAccumulateReductionRecipe (this );
2656+ Copy->transferFlags (*this );
2657+ return Copy;
2658+ }
2659+
2660+ VP_CLASSOF_IMPL (VPDef::VPMulAccumulateReductionSC);
2661+
2662+ void execute (VPTransformState &State) override {
2663+ llvm_unreachable (" VPMulAccumulateReductionRecipe should transform to "
2664+ " VPWidenCastRecipe + "
2665+ " VPWidenRecipe + VPReductionRecipe before execution" );
2666+ }
2667+
2668+ // / Return the cost of VPMulAccumulateReductionRecipe.
2669+ InstructionCost computeCost (ElementCount VF,
2670+ VPCostContext &Ctx) const override ;
2671+
2672+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2673+ // / Print the recipe.
2674+ void print (raw_ostream &O, const Twine &Indent,
2675+ VPSlotTracker &SlotTracker) const override ;
2676+ #endif
2677+
2678+ Type *getResultType () const {
2679+ assert (isExtended () && " Only support getResultType when this recipe "
2680+ " contains implicit extend." );
2681+ return ResultTy;
2682+ }
2683+
2684+ // / The VPValue of the vector value to be extended and reduced.
2685+ VPValue *getVecOp0 () const { return getOperand (1 ); }
2686+ VPValue *getVecOp1 () const { return getOperand (2 ); }
2687+
2688+ // / Return if this MulAcc recipe contains extended operands.
2689+ bool isExtended () const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2690+
2691+ // / Return the opcode of the extends for the operands.
2692+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2693+
2694+ // / Return if the operands are zero extended.
2695+ bool isZExt () const { return ExtOp == Instruction::CastOps::ZExt; }
2696+
2697+ // / Return the non negative flag of the ext recipe.
2698+ bool isNonNeg () const { return IsNonNeg; }
2699+ };
2700+
24742701// / VPReplicateRecipe replicates a given instruction producing multiple scalar
24752702// / copies of the original scalar type, one per lane, instead of producing a
24762703// / single copy of widened type for all lanes. If the instruction is known to be
0 commit comments