Skip to content

Commit e3df744

Browse files
ElvisWang123lukel97
authored andcommitted
[VPlan] Change parent of VPReductionRecipe to VPRecipeWithIRFlags. NFC
This patch change the parent of the VPReductionRecipe from VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/control flags by the VPRecipeWithIRFlags. This will remove the dependency of the underlying instruction. This patch also add a new function `setFastMathFlags()` to the VPRecipeWithIRFlags because the entire reduction chain may contains multiple instructions. And the underlying instruction may not contains the corresponding flags for this reduction.
1 parent 4ac2a49 commit e3df744

File tree

3 files changed

+36
-21
lines changed

3 files changed

+36
-21
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9797,9 +9797,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97979797
if (CM.blockNeedsPredicationForAnyReason(BB))
97989798
CondOp = RecipeBuilder.getBlockInMask(BB);
97999799

9800-
auto *RedRecipe = new VPReductionRecipe(
9801-
RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
9802-
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
9800+
auto *RedRecipe =
9801+
new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
9802+
CondOp, CM.useOrderedReductions(RdxDesc));
98039803
// Append the recipe to the end of the VPBasicBlock because we need to
98049804
// ensure that it comes after all of it's inputs, including CondOp.
98059805
// Delete CurrentLink as it will be invalid if its operand is replaced

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
711711
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
712712
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
713713
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
714+
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
715+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
714716
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
715717
R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
716718
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
@@ -786,6 +788,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
786788
}
787789
}
788790

791+
/// Set fast-math flags for this recipe.
792+
void setFastMathFlags(FastMathFlags FMFs) {
793+
OpType = OperationType::FPMathOp;
794+
this->FMFs = FMFs;
795+
}
796+
789797
CmpInst::Predicate getPredicate() const {
790798
assert(OpType == OperationType::Cmp &&
791799
"recipe doesn't have a compare predicate");
@@ -2236,7 +2244,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
22362244
/// A recipe to represent inloop reduction operations, performing a reduction on
22372245
/// a vector operand into a scalar value, and adding the result to a chain.
22382246
/// The Operands are {ChainOp, VecOp, [Condition]}.
2239-
class VPReductionRecipe : public VPSingleDefRecipe {
2247+
class VPReductionRecipe : public VPRecipeWithIRFlags {
22402248
/// The recurrence decriptor for the reduction in question.
22412249
const RecurrenceDescriptor &RdxDesc;
22422250
bool IsOrdered;
@@ -2246,29 +2254,32 @@ class VPReductionRecipe : public VPSingleDefRecipe {
22462254
protected:
22472255
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
22482256
Instruction *I, ArrayRef<VPValue *> Operands,
2249-
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
2250-
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
2257+
VPValue *CondOp, bool IsOrdered)
2258+
: VPRecipeWithIRFlags(SC, Operands, *I), RdxDesc(R),
22512259
IsOrdered(IsOrdered) {
22522260
if (CondOp) {
22532261
IsConditional = true;
22542262
addOperand(CondOp);
22552263
}
2264+
// The inloop reduction may across multiple scalar instruction and the
2265+
// underlying instruction may not contains the corresponding flags. Set the
2266+
// flags explicit from the redurrence descriptor.
2267+
setFastMathFlags(R.getFastMathFlags());
22562268
}
22572269

22582270
public:
22592271
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
22602272
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2261-
bool IsOrdered, DebugLoc DL = {})
2273+
bool IsOrdered)
22622274
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
22632275
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2264-
IsOrdered, DL) {}
2276+
IsOrdered) {}
22652277

22662278
~VPReductionRecipe() override = default;
22672279

22682280
VPReductionRecipe *clone() override {
22692281
return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
2270-
getVecOp(), getCondOp(), IsOrdered,
2271-
getDebugLoc());
2282+
getVecOp(), getCondOp(), IsOrdered);
22722283
}
22732284

22742285
static inline bool classof(const VPRecipeBase *R) {
@@ -2323,7 +2334,7 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
23232334
VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
23242335
cast_or_null<Instruction>(R.getUnderlyingValue()),
23252336
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
2326-
R.isOrdered(), R.getDebugLoc()) {}
2337+
R.isOrdered()) {}
23272338

23282339
~VPReductionEVLRecipe() override = default;
23292340

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,7 +2290,8 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22902290
"In-loop AnyOf reductions aren't currently supported");
22912291
// Propagate the fast-math flags carried by the underlying instruction.
22922292
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
2293-
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2293+
if (hasFastMathFlags())
2294+
State.Builder.setFastMathFlags(getFastMathFlags());
22942295
State.setDebugLocFrom(getDebugLoc());
22952296
Value *NewVecOp = State.get(getVecOp());
22962297
if (VPValue *Cond = getCondOp()) {
@@ -2337,7 +2338,8 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23372338
// Propagate the fast-math flags carried by the underlying instruction.
23382339
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
23392340
const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
2340-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
2341+
if (hasFastMathFlags())
2342+
Builder.setFastMathFlags(getFastMathFlags());
23412343

23422344
RecurKind Kind = RdxDesc.getRecurrenceKind();
23432345
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
@@ -2374,6 +2376,8 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23742376
Type *ElementTy = Ctx.Types.inferScalarType(this);
23752377
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
23762378
unsigned Opcode = RdxDesc.getOpcode();
2379+
FastMathFlags FMFs =
2380+
hasFastMathFlags() ? getFastMathFlags() : FastMathFlags();
23772381

23782382
// TODO: Support any-of and in-loop reductions.
23792383
assert(
@@ -2393,12 +2397,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23932397
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
23942398
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
23952399
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2396-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2397-
Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2400+
return Cost +
2401+
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
23982402
}
23992403

2400-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2401-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
2404+
return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2405+
Ctx.CostKind);
24022406
}
24032407

24042408
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2409,8 +2413,8 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24092413
O << " = ";
24102414
getChainOp()->printAsOperand(O, SlotTracker);
24112415
O << " +";
2412-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2413-
O << getUnderlyingInstr()->getFastMathFlags();
2416+
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
2417+
printFlags(O);
24142418
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
24152419
getVecOp()->printAsOperand(O, SlotTracker);
24162420
if (isConditional()) {
@@ -2431,8 +2435,8 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24312435
O << " = ";
24322436
getChainOp()->printAsOperand(O, SlotTracker);
24332437
O << " +";
2434-
if (isa<FPMathOperator>(getUnderlyingInstr()))
2435-
O << getUnderlyingInstr()->getFastMathFlags();
2438+
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
2439+
printFlags(O);
24362440
O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
24372441
getVecOp()->printAsOperand(O, SlotTracker);
24382442
O << ", ";

0 commit comments

Comments
 (0)