@@ -577,8 +577,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
577577#endif
578578};
579579
580- // / Class to record LLVM IR flag for a recipe along with it .
581- class VPRecipeWithIRFlags : public VPSingleDefRecipe {
580+ // / Class to record and manage LLVM IR flags .
581+ class VPIRFlags {
582582 enum class OperationType : unsigned char {
583583 Cmp,
584584 OverflowingBinOp,
@@ -637,23 +637,10 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
637637 unsigned AllFlags;
638638 };
639639
640- protected:
641- void transferFlags (VPRecipeWithIRFlags &Other) {
642- OpType = Other.OpType ;
643- AllFlags = Other.AllFlags ;
644- }
645-
646640public:
647- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
648- DebugLoc DL = {})
649- : VPSingleDefRecipe(SC, Operands, DL) {
650- OpType = OperationType::Other;
651- AllFlags = 0 ;
652- }
641+ VPIRFlags () : OpType(OperationType::Other), AllFlags(0 ) {}
653642
654- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
655- Instruction &I)
656- : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()) {
643+ VPIRFlags (Instruction &I) {
657644 if (auto *Op = dyn_cast<CmpInst>(&I)) {
658645 OpType = OperationType::Cmp;
659646 CmpPredicate = Op->getPredicate ();
@@ -681,63 +668,27 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
681668 }
682669 }
683670
684- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
685- CmpInst::Predicate Pred, DebugLoc DL = {})
686- : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::Cmp),
687- CmpPredicate (Pred) {}
671+ VPIRFlags (CmpInst::Predicate Pred)
672+ : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
688673
689- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
690- WrapFlagsTy WrapFlags, DebugLoc DL = {})
691- : VPSingleDefRecipe(SC, Operands, DL),
692- OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
674+ VPIRFlags (WrapFlagsTy WrapFlags)
675+ : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
693676
694- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
695- FastMathFlags FMFs, DebugLoc DL = {})
696- : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp),
697- FMFs(FMFs) {}
677+ VPIRFlags (FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
698678
699- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
700- DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
701- : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
702- DisjointFlags(DisjointFlags) {}
679+ VPIRFlags (DisjointFlagsTy DisjointFlags)
680+ : OpType(OperationType::DisjointOp), DisjointFlags(DisjointFlags) {}
703681
704- template <typename IterT>
705- VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
706- NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
707- : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
708- NonNegFlags(NonNegFlags) {}
682+ VPIRFlags (NonNegFlagsTy NonNegFlags)
683+ : OpType(OperationType::NonNegOp), NonNegFlags(NonNegFlags) {}
709684
710- protected:
711- VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
712- GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
713- : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::GEPOp),
714- GEPFlags(GEPFlags) {}
685+ VPIRFlags (GEPNoWrapFlags GEPFlags)
686+ : OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
715687
716688public:
717- static inline bool classof (const VPRecipeBase *R) {
718- return R->getVPDefID () == VPRecipeBase::VPInstructionSC ||
719- R->getVPDefID () == VPRecipeBase::VPWidenSC ||
720- R->getVPDefID () == VPRecipeBase::VPWidenGEPSC ||
721- R->getVPDefID () == VPRecipeBase::VPWidenCallSC ||
722- R->getVPDefID () == VPRecipeBase::VPWidenCastSC ||
723- R->getVPDefID () == VPRecipeBase::VPWidenIntrinsicSC ||
724- R->getVPDefID () == VPRecipeBase::VPReductionSC ||
725- R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
726- R->getVPDefID () == VPRecipeBase::VPReplicateSC ||
727- R->getVPDefID () == VPRecipeBase::VPVectorEndPointerSC ||
728- R->getVPDefID () == VPRecipeBase::VPVectorPointerSC ||
729- R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
730- R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
731- }
732-
733- static inline bool classof (const VPUser *U) {
734- auto *R = dyn_cast<VPRecipeBase>(U);
735- return R && classof (R);
736- }
737-
738- static inline bool classof (const VPValue *V) {
739- auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe ());
740- return R && classof (R);
689+ void transferFlags (VPIRFlags &Other) {
690+ OpType = Other.OpType ;
691+ AllFlags = Other.AllFlags ;
741692 }
742693
743694 // / Drop all poison-generating flags.
@@ -851,11 +802,60 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
851802 return DisjointFlags.IsDisjoint ;
852803 }
853804
805+ #if !defined(NDEBUG)
806+ // / Returns true if the set flags are valid for \p Opcode.
807+ bool flagsValidForOpcode (unsigned Opcode) const ;
808+ #endif
809+
854810#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
855811 void printFlags (raw_ostream &O) const ;
856812#endif
857813};
858814
815+ // / A pure-virtual common base class for recipes defining a single VPValue and
816+ // / using IR flags.
817+ struct VPRecipeWithIRFlags : public VPSingleDefRecipe , public VPIRFlags {
818+ VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
819+ DebugLoc DL = {})
820+ : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags() {}
821+
822+ VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
823+ Instruction &I)
824+ : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()), VPIRFlags(I) {}
825+
826+ VPRecipeWithIRFlags (const unsigned char SC, ArrayRef<VPValue *> Operands,
827+ const VPIRFlags &Flags, DebugLoc DL = {})
828+ : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags(Flags) {}
829+
830+ static inline bool classof (const VPRecipeBase *R) {
831+ return R->getVPDefID () == VPRecipeBase::VPInstructionSC ||
832+ R->getVPDefID () == VPRecipeBase::VPWidenSC ||
833+ R->getVPDefID () == VPRecipeBase::VPWidenGEPSC ||
834+ R->getVPDefID () == VPRecipeBase::VPWidenCallSC ||
835+ R->getVPDefID () == VPRecipeBase::VPWidenCastSC ||
836+ R->getVPDefID () == VPRecipeBase::VPWidenIntrinsicSC ||
837+ R->getVPDefID () == VPRecipeBase::VPReductionSC ||
838+ R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
839+ R->getVPDefID () == VPRecipeBase::VPReplicateSC ||
840+ R->getVPDefID () == VPRecipeBase::VPVectorEndPointerSC ||
841+ R->getVPDefID () == VPRecipeBase::VPVectorPointerSC ||
842+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
843+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
844+ }
845+
846+ static inline bool classof (const VPUser *U) {
847+ auto *R = dyn_cast<VPRecipeBase>(U);
848+ return R && classof (R);
849+ }
850+
851+ static inline bool classof (const VPValue *V) {
852+ auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe ());
853+ return R && classof (R);
854+ }
855+
856+ void execute (VPTransformState &State) override = 0;
857+ };
858+
859859// / Helper to access the operand that contains the unroll part for this recipe
860860// / after unrolling.
861861template <unsigned PartOpIdx> class VPUnrollPartAccessor {
@@ -958,54 +958,21 @@ class VPInstruction : public VPRecipeWithIRFlags,
958958 // / value for lane \p Lane.
959959 Value *generatePerLane (VPTransformState &State, const VPLane &Lane);
960960
961- #if !defined(NDEBUG)
962- // / Return true if the VPInstruction is a floating point math operation, i.e.
963- // / has fast-math flags.
964- bool isFPMathOp () const ;
965- #endif
966-
967961public:
968- VPInstruction (unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
962+ VPInstruction (unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {} ,
969963 const Twine &Name = " " )
970964 : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
971965 Opcode (Opcode), Name(Name.str()) {}
972966
973- VPInstruction (unsigned Opcode, std::initializer_list<VPValue *> Operands,
974- DebugLoc DL = {}, const Twine &Name = " " )
975- : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
976-
977- VPInstruction (unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
978- VPValue *B, DebugLoc DL = {}, const Twine &Name = " " );
979-
980- VPInstruction (unsigned Opcode, std::initializer_list<VPValue *> Operands,
981- WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = " " )
982- : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
983- Opcode(Opcode), Name(Name.str()) {}
984-
985- VPInstruction (unsigned Opcode, std::initializer_list<VPValue *> Operands,
986- DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
987- const Twine &Name = " " )
988- : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL),
989- Opcode(Opcode), Name(Name.str()) {
990- assert (Opcode == Instruction::Or && " only OR opcodes can be disjoint" );
991- }
992-
993- VPInstruction (VPValue *Ptr, VPValue *Offset, GEPNoWrapFlags Flags,
994- DebugLoc DL = {}, const Twine &Name = " " )
995- : VPRecipeWithIRFlags(VPDef::VPInstructionSC,
996- ArrayRef<VPValue *>({Ptr, Offset}), Flags, DL),
997- Opcode(VPInstruction::PtrAdd), Name(Name.str()) {}
998-
999- VPInstruction (unsigned Opcode, std::initializer_list<VPValue *> Operands,
1000- FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = " " );
967+ VPInstruction (unsigned Opcode, ArrayRef<VPValue *> Operands,
968+ const VPIRFlags &Flags, DebugLoc DL = {},
969+ const Twine &Name = " " );
1001970
1002971 VP_CLASSOF_IMPL (VPDef::VPInstructionSC)
1003972
1004973 VPInstruction *clone() override {
1005974 SmallVector<VPValue *, 2 > Operands (operands ());
1006- auto *New = new VPInstruction (Opcode, Operands, getDebugLoc (), Name);
1007- New->transferFlags (*this );
1008- return New;
975+ return new VPInstruction (Opcode, Operands, *this , getDebugLoc (), Name);
1009976 }
1010977
1011978 unsigned getOpcode () const { return Opcode; }
@@ -1082,13 +1049,9 @@ class VPInstructionWithType : public VPInstruction {
10821049
10831050public:
10841051 VPInstructionWithType (unsigned Opcode, ArrayRef<VPValue *> Operands,
1085- Type *ResultTy, DebugLoc DL, const Twine &Name = " " )
1086- : VPInstruction(Opcode, Operands, DL, Name), ResultTy(ResultTy) {}
1087- VPInstructionWithType (unsigned Opcode,
1088- std::initializer_list<VPValue *> Operands,
1089- Type *ResultTy, FastMathFlags FMFs, DebugLoc DL = {},
1052+ Type *ResultTy, const VPIRFlags &Flags, DebugLoc DL,
10901053 const Twine &Name = " " )
1091- : VPInstruction(Opcode, Operands, FMFs , DL, Name), ResultTy(ResultTy) {}
1054+ : VPInstruction(Opcode, Operands, Flags , DL, Name), ResultTy(ResultTy) {}
10921055
10931056 static inline bool classof (const VPRecipeBase *R) {
10941057 // VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1113,8 +1076,9 @@ class VPInstructionWithType : public VPInstruction {
11131076
11141077 VPInstruction *clone () override {
11151078 SmallVector<VPValue *, 2 > Operands (operands ());
1116- auto *New = new VPInstructionWithType (
1117- getOpcode (), Operands, getResultType (), getDebugLoc (), getName ());
1079+ auto *New =
1080+ new VPInstructionWithType (getOpcode (), Operands, getResultType (), *this ,
1081+ getDebugLoc (), getName ());
11181082 New->setUnderlyingValue (getUnderlyingValue ());
11191083 return New;
11201084 }
@@ -1373,15 +1337,12 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
13731337 }
13741338
13751339 VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1376- DebugLoc DL = {})
1377- : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1378- Opcode(Opcode), ResultTy(ResultTy) {}
1379-
1380- VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1381- bool IsNonNeg, DebugLoc DL = {})
1382- : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1383- DL),
1384- Opcode(Opcode), ResultTy(ResultTy) {}
1340+ const VPIRFlags &Flags = {}, DebugLoc DL = {})
1341+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, Flags, DL),
1342+ VPIRMetadata(), Opcode(Opcode), ResultTy(ResultTy) {
1343+ assert (flagsValidForOpcode (Opcode) &&
1344+ " Set flags not supported for the provided opcode" );
1345+ }
13851346
13861347 ~VPWidenCastRecipe () override = default ;
13871348
0 commit comments