@@ -517,6 +517,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
517517 case VPRecipeBase::VPInstructionSC:
518518 case VPRecipeBase::VPReductionEVLSC:
519519 case VPRecipeBase::VPReductionSC:
520+ case VPRecipeBase::VPMulAccumulateReductionSC:
521+ case VPRecipeBase::VPExtendedReductionSC:
520522 case VPRecipeBase::VPReplicateSC:
521523 case VPRecipeBase::VPScalarIVStepsSC:
522524 case VPRecipeBase::VPVectorPointerSC:
@@ -601,13 +603,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
601603 DisjointFlagsTy (bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
602604 };
603605
/// Flag payload for recipes whose only IR flag is the non-negative (NNEG)
/// flag.
struct NonNegFlagsTy {
  /// One-bit storage for the flag value.
  char NonNeg : 1;

  /// Initialize the stored bit from \p IsNonNeg.
  NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
};
610+
604611private:
605612 struct ExactFlagsTy {
606613 char IsExact : 1 ;
607614 };
608- struct NonNegFlagsTy {
609- char NonNeg : 1 ;
610- };
611615 struct FastMathFlagsTy {
612616 char AllowReassoc : 1 ;
613617 char NoNaNs : 1 ;
@@ -697,6 +701,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
697701 : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
698702 DisjointFlags(DisjointFlags) {}
699703
704+ template <typename IterT>
705+ VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
706+ NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
707+ : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
708+ NonNegFlags(NonNegFlags) {}
709+
700710protected:
701711 VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
702712 GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
@@ -715,7 +725,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
715725 R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
716726 R->getVPDefID () == VPRecipeBase::VPReplicateSC ||
717727 R->getVPDefID () == VPRecipeBase::VPVectorEndPointerSC ||
718- R->getVPDefID () == VPRecipeBase::VPVectorPointerSC;
728+ R->getVPDefID () == VPRecipeBase::VPVectorPointerSC ||
729+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
730+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
719731 }
720732
721733 static inline bool classof (const VPUser *U) {
@@ -812,6 +824,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
812824
813825 FastMathFlags getFastMathFlags () const ;
814826
827+ // / Returns true if the recipe has non-negative flag.
828+ bool hasNonNegFlag () const { return OpType == OperationType::NonNegOp; }
829+
830+ bool isNonNeg () const {
831+ assert (OpType == OperationType::NonNegOp &&
832+ " recipe doesn't have a NNEG flag" );
833+ return NonNegFlags.NonNeg ;
834+ }
835+
815836 bool hasNoUnsignedWrap () const {
816837 assert (OpType == OperationType::OverflowingBinOp &&
817838 " recipe doesn't have a NUW flag" );
@@ -1289,10 +1310,21 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12891310 : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
12901311 Opcode (I.getOpcode()) {}
12911312
1313+ template <typename IterT>
1314+ VPWidenRecipe (unsigned VPDefOpcode, unsigned Opcode,
1315+ iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
1316+ : VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
1317+ Opcode(Opcode) {}
1318+
12921319public:
12931320 VPWidenRecipe (Instruction &I, ArrayRef<VPValue *> Operands)
12941321 : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
12951322
1323+ template <typename IterT>
1324+ VPWidenRecipe (unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
1325+ bool NSW, DebugLoc DL)
1326+ : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
1327+
12961328 ~VPWidenRecipe () override = default ;
12971329
12981330 VPWidenRecipe *clone () override {
@@ -1337,8 +1369,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
13371369 " opcode of underlying cast doesn't match" );
13381370 }
13391371
1340- VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
1341- : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
1372+ VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1373+ DebugLoc DL = {})
1374+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1375+ Opcode(Opcode), ResultTy(ResultTy) {}
1376+
1377+ VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1378+ bool IsNonNeg, DebugLoc DL = {})
1379+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1380+ DL),
13421381 Opcode(Opcode), ResultTy(ResultTy) {}
13431382
13441383 ~VPWidenCastRecipe () override = default ;
@@ -2381,6 +2420,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23812420 setUnderlyingValue (I);
23822421 }
23832422
2423+ // / For VPExtendedReductionRecipe.
2424+ // / Note that the debug location is from the extend.
2425+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2426+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2427+ bool IsOrdered, DebugLoc DL)
2428+ : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2429+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2430+ if (CondOp)
2431+ addOperand (CondOp);
2432+ }
2433+
2434+ // / For VPMulAccumulateReductionRecipe.
2435+ // / Note that the NUW/NSW flags and the debug location are from the Mul.
2436+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2437+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2438+ bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2439+ : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2440+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2441+ if (CondOp)
2442+ addOperand (CondOp);
2443+ }
2444+
23842445public:
23852446 VPReductionRecipe (RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23862447 VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2389,6 +2450,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23892450 ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23902451 IsOrdered, DL) {}
23912452
2453+ VPReductionRecipe (const RecurKind RdxKind, FastMathFlags FMFs,
2454+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2455+ bool IsOrdered, DebugLoc DL = {})
2456+ : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr ,
2457+ ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2458+ IsOrdered, DL) {}
2459+
23922460 ~VPReductionRecipe () override = default ;
23932461
23942462 VPReductionRecipe *clone () override {
@@ -2399,7 +2467,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23992467
24002468 static inline bool classof (const VPRecipeBase *R) {
24012469 return R->getVPDefID () == VPRecipeBase::VPReductionSC ||
2402- R->getVPDefID () == VPRecipeBase::VPReductionEVLSC;
2470+ R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
2471+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
2472+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
24032473 }
24042474
24052475 static inline bool classof (const VPUser *U) {
@@ -2538,6 +2608,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
25382608 }
25392609};
25402610
2611+ // / A recipe to represent inloop extended reduction operations, performing a
2612+ // / reduction on a extended vector operand into a scalar value, and adding the
2613+ // / result to a chain. This recipe is abstract and needs to be lowered to
2614+ // / concrete recipes before codegen. The operands are {ChainOp, VecOp,
2615+ // / [Condition]}.
2616+ class VPExtendedReductionRecipe : public VPReductionRecipe {
2617+ // / Opcode of the extend recipe will be lowered to.
2618+ Instruction::CastOps ExtOp;
2619+
2620+ Type *ResultTy;
2621+
2622+ // / For cloning VPExtendedReductionRecipe.
2623+ VPExtendedReductionRecipe (VPExtendedReductionRecipe *ExtRed)
2624+ : VPReductionRecipe(
2625+ VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind (),
2626+ {ExtRed->getChainOp (), ExtRed->getVecOp ()}, ExtRed->getCondOp (),
2627+ ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2628+ ExtOp(ExtRed->getExtOpcode ()), ResultTy(ExtRed->getResultType ()) {
2629+ transferFlags (*ExtRed);
2630+ }
2631+
2632+ public:
2633+ VPExtendedReductionRecipe (VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2634+ : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind (),
2635+ {R->getChainOp (), Ext->getOperand (0 )}, R->getCondOp (),
2636+ R->isOrdered(), Ext->getDebugLoc()),
2637+ ExtOp(Ext->getOpcode ()), ResultTy(Ext->getResultType ()) {
2638+ // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2639+ // the original recipe to prevent setting wrong flags.
2640+ transferFlags (*Ext);
2641+ }
2642+
2643+ ~VPExtendedReductionRecipe () override = default ;
2644+
2645+ VPExtendedReductionRecipe *clone () override {
2646+ auto *Copy = new VPExtendedReductionRecipe (this );
2647+ Copy->transferFlags (*this );
2648+ return Copy;
2649+ }
2650+
2651+ VP_CLASSOF_IMPL (VPDef::VPExtendedReductionSC);
2652+
2653+ void execute (VPTransformState &State) override {
2654+ llvm_unreachable (" VPExtendedReductionRecipe should be transform to "
2655+ " VPExtendedRecipe + VPReductionRecipe before execution." );
2656+ };
2657+
2658+ // / Return the cost of VPExtendedReductionRecipe.
2659+ InstructionCost computeCost (ElementCount VF,
2660+ VPCostContext &Ctx) const override ;
2661+
2662+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2663+ // / Print the recipe.
2664+ void print (raw_ostream &O, const Twine &Indent,
2665+ VPSlotTracker &SlotTracker) const override ;
2666+ #endif
2667+
2668+ // / The scalar type after extending.
2669+ Type *getResultType () const { return ResultTy; }
2670+
2671+ // / Is the extend ZExt?
2672+ bool isZExt () const { return getExtOpcode () == Instruction::ZExt; }
2673+
2674+ // / The opcode of extend recipe.
2675+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2676+ };
2677+
2678+ // / A recipe to represent inloop MulAccumulateReduction operations, performing a
2679+ // / reduction.add on the result of vector operands (might be extended)
2680+ // / multiplication into a scalar value, and adding the result to a chain. This
2681+ // / recipe is abstract and needs to be lowered to concrete recipes before
2682+ // / codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2683+ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2684+ // / Opcode of the extend recipe.
2685+ Instruction::CastOps ExtOp;
2686+
2687+ // / Non-neg flag of the extend recipe.
2688+ bool IsNonNeg = false ;
2689+
2690+ Type *ResultTy;
2691+
2692+ // / For cloning VPMulAccumulateReductionRecipe.
2693+ VPMulAccumulateReductionRecipe (VPMulAccumulateReductionRecipe *MulAcc)
2694+ : VPReductionRecipe(
2695+ VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind (),
2696+ {MulAcc->getChainOp (), MulAcc->getVecOp0 (), MulAcc->getVecOp1 ()},
2697+ MulAcc->getCondOp (), MulAcc->isOrdered(),
2698+ WrapFlagsTy(MulAcc->hasNoUnsignedWrap (), MulAcc->hasNoSignedWrap()),
2699+ MulAcc->getDebugLoc()),
2700+ ExtOp(MulAcc->getExtOpcode ()), IsNonNeg(MulAcc->isNonNeg ()),
2701+ ResultTy(MulAcc->getResultType ()) {}
2702+
2703+ public:
2704+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul,
2705+ VPWidenCastRecipe *Ext0,
2706+ VPWidenCastRecipe *Ext1, Type *ResultTy)
2707+ : VPReductionRecipe(
2708+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2709+ {R->getChainOp (), Ext0->getOperand (0 ), Ext1->getOperand (0 )},
2710+ R->getCondOp (), R->isOrdered(),
2711+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2712+ R->getDebugLoc()),
2713+ ExtOp(Ext0->getOpcode ()), ResultTy(ResultTy) {
2714+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2715+ Instruction::Add &&
2716+ " The reduction instruction in MulAccumulateteReductionRecipe must "
2717+ " be Add" );
2718+ // Only set the non-negative flag if the original recipe contains.
2719+ if (Ext0->hasNonNegFlag ())
2720+ IsNonNeg = Ext0->isNonNeg ();
2721+ }
2722+
2723+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul)
2724+ : VPReductionRecipe(
2725+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2726+ {R->getChainOp (), Mul->getOperand (0 ), Mul->getOperand (1 )},
2727+ R->getCondOp (), R->isOrdered(),
2728+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2729+ R->getDebugLoc()),
2730+ ExtOp(Instruction::CastOps::CastOpsEnd) {
2731+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2732+ Instruction::Add &&
2733+ " The reduction instruction in MulAccumulateReductionRecipe must be "
2734+ " Add" );
2735+ }
2736+
2737+ ~VPMulAccumulateReductionRecipe () override = default ;
2738+
2739+ VPMulAccumulateReductionRecipe *clone () override {
2740+ auto *Copy = new VPMulAccumulateReductionRecipe (this );
2741+ Copy->transferFlags (*this );
2742+ return Copy;
2743+ }
2744+
2745+ VP_CLASSOF_IMPL (VPDef::VPMulAccumulateReductionSC);
2746+
2747+ void execute (VPTransformState &State) override {
2748+ llvm_unreachable (" VPMulAccumulateReductionRecipe should transform to "
2749+ " VPWidenCastRecipe + "
2750+ " VPWidenRecipe + VPReductionRecipe before execution" );
2751+ }
2752+
2753+ // / Return the cost of VPMulAccumulateReductionRecipe.
2754+ InstructionCost computeCost (ElementCount VF,
2755+ VPCostContext &Ctx) const override ;
2756+
2757+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2758+ // / Print the recipe.
2759+ void print (raw_ostream &O, const Twine &Indent,
2760+ VPSlotTracker &SlotTracker) const override ;
2761+ #endif
2762+
2763+ Type *getResultType () const {
2764+ assert (isExtended () && " Only support getResultType when this recipe "
2765+ " contains implicit extend." );
2766+ return ResultTy;
2767+ }
2768+
2769+ // / The VPValue of the vector value to be extended and reduced.
2770+ VPValue *getVecOp0 () const { return getOperand (1 ); }
2771+ VPValue *getVecOp1 () const { return getOperand (2 ); }
2772+
2773+ // / Return if this MulAcc recipe contains extended operands.
2774+ bool isExtended () const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2775+
2776+ // / Return the opcode of the extends for the operands.
2777+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2778+
2779+ // / Return if the operands are zero extended.
2780+ bool isZExt () const { return ExtOp == Instruction::CastOps::ZExt; }
2781+
2782+ // / Return the non negative flag of the ext recipe.
2783+ bool isNonNeg () const { return IsNonNeg; }
2784+ };
2785+
25412786// / VPReplicateRecipe replicates a given instruction producing multiple scalar
25422787// / copies of the original scalar type, one per lane, instead of producing a
25432788// / single copy of widened type for all lanes. If the instruction is known to be
0 commit comments