Skip to content

Commit 1e3ea03

Browse files
committed
[VPlan] VPIRFlags kind for FCmp with predicate + fast-math flags (NFCI).
FCmp instructions have both a predicate and fast-math flags. Introduce a new FCmp kind that combines both, so that they are modeled correctly in the current system. This should be NFC modulo VPlan printing, which now includes the correct fast-math flags.
1 parent 5cde345 commit 1e3ea03

File tree

4 files changed

+84
-36
lines changed

4 files changed

+84
-36
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
609609
class VPIRFlags {
610610
enum class OperationType : unsigned char {
611611
Cmp,
612+
FCmp,
612613
OverflowingBinOp,
613614
Trunc,
614615
DisjointOp,
@@ -659,6 +660,12 @@ class VPIRFlags {
659660

660661
LLVM_ABI_FOR_TEST FastMathFlagsTy(const FastMathFlags &FMF);
661662
};
663+
/// Holds both the predicate and fast-math flags for floating-point
664+
/// comparisons.
665+
struct FCmpFlagsTy {
666+
CmpInst::Predicate Pred;
667+
FastMathFlagsTy FMFs;
668+
};
662669

663670
OperationType OpType;
664671

@@ -671,14 +678,19 @@ class VPIRFlags {
671678
GEPNoWrapFlags GEPFlags;
672679
NonNegFlagsTy NonNegFlags;
673680
FastMathFlagsTy FMFs;
681+
FCmpFlagsTy FCmpFlags;
674682
unsigned AllFlags;
675683
};
676684

677685
public:
678686
VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
679687

680688
VPIRFlags(Instruction &I) {
681-
if (auto *Op = dyn_cast<CmpInst>(&I)) {
689+
if (auto *FCmp = dyn_cast<FCmpInst>(&I)) {
690+
OpType = OperationType::FCmp;
691+
FCmpFlags.Pred = FCmp->getPredicate();
692+
FCmpFlags.FMFs = FCmp->getFastMathFlags();
693+
} else if (auto *Op = dyn_cast<CmpInst>(&I)) {
682694
OpType = OperationType::Cmp;
683695
CmpPredicate = Op->getPredicate();
684696
} else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) {
@@ -711,6 +723,12 @@ class VPIRFlags {
711723
VPIRFlags(CmpInst::Predicate Pred)
712724
: OpType(OperationType::Cmp), CmpPredicate(Pred) {}
713725

726+
VPIRFlags(CmpInst::Predicate Pred, FastMathFlags FMFs)
727+
: OpType(OperationType::FCmp) {
728+
FCmpFlags.Pred = Pred;
729+
FCmpFlags.FMFs = FMFs;
730+
}
731+
714732
VPIRFlags(WrapFlagsTy WrapFlags)
715733
: OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
716734

@@ -760,8 +778,9 @@ class VPIRFlags {
760778
GEPFlags = GEPNoWrapFlags::none();
761779
break;
762780
case OperationType::FPMathOp:
763-
FMFs.NoNaNs = false;
764-
FMFs.NoInfs = false;
781+
case OperationType::FCmp:
782+
getFMFsRef().NoNaNs = false;
783+
getFMFsRef().NoInfs = false;
765784
break;
766785
case OperationType::NonNegOp:
767786
NonNegFlags.NonNeg = false;
@@ -793,14 +812,17 @@ class VPIRFlags {
793812
cast<GetElementPtrInst>(&I)->setNoWrapFlags(GEPFlags);
794813
break;
795814
case OperationType::FPMathOp:
796-
I.setHasAllowReassoc(FMFs.AllowReassoc);
797-
I.setHasNoNaNs(FMFs.NoNaNs);
798-
I.setHasNoInfs(FMFs.NoInfs);
799-
I.setHasNoSignedZeros(FMFs.NoSignedZeros);
800-
I.setHasAllowReciprocal(FMFs.AllowReciprocal);
801-
I.setHasAllowContract(FMFs.AllowContract);
802-
I.setHasApproxFunc(FMFs.ApproxFunc);
815+
case OperationType::FCmp: {
816+
const FastMathFlagsTy &F = getFMFsRef();
817+
I.setHasAllowReassoc(F.AllowReassoc);
818+
I.setHasNoNaNs(F.NoNaNs);
819+
I.setHasNoInfs(F.NoInfs);
820+
I.setHasNoSignedZeros(F.NoSignedZeros);
821+
I.setHasAllowReciprocal(F.AllowReciprocal);
822+
I.setHasAllowContract(F.AllowContract);
823+
I.setHasApproxFunc(F.ApproxFunc);
803824
break;
825+
}
804826
case OperationType::NonNegOp:
805827
I.setNonNeg(NonNegFlags.NonNeg);
806828
break;
@@ -811,24 +833,31 @@ class VPIRFlags {
811833
}
812834

813835
CmpInst::Predicate getPredicate() const {
814-
assert(OpType == OperationType::Cmp &&
836+
assert((OpType == OperationType::Cmp || OpType == OperationType::FCmp) &&
815837
"recipe doesn't have a compare predicate");
816-
return CmpPredicate;
838+
return OpType == OperationType::FCmp ? FCmpFlags.Pred : CmpPredicate;
817839
}
818840

819841
void setPredicate(CmpInst::Predicate Pred) {
820-
assert(OpType == OperationType::Cmp &&
842+
assert((OpType == OperationType::Cmp || OpType == OperationType::FCmp) &&
821843
"recipe doesn't have a compare predicate");
822-
CmpPredicate = Pred;
844+
if (OpType == OperationType::FCmp)
845+
FCmpFlags.Pred = Pred;
846+
else
847+
CmpPredicate = Pred;
823848
}
824849

825850
GEPNoWrapFlags getGEPNoWrapFlags() const { return GEPFlags; }
826851

827852
/// Returns true if the recipe has a comparison predicate.
828-
bool hasPredicate() const { return OpType == OperationType::Cmp; }
853+
bool hasPredicate() const {
854+
return OpType == OperationType::Cmp || OpType == OperationType::FCmp;
855+
}
829856

830857
/// Returns true if the recipe has fast-math flags.
831-
bool hasFastMathFlags() const { return OpType == OperationType::FPMathOp; }
858+
bool hasFastMathFlags() const {
859+
return OpType == OperationType::FPMathOp || OpType == OperationType::FCmp;
860+
}
832861

833862
LLVM_ABI_FOR_TEST FastMathFlags getFastMathFlags() const;
834863

@@ -869,6 +898,16 @@ class VPIRFlags {
869898
return DisjointFlags.IsDisjoint;
870899
}
871900

901+
private:
902+
/// Get a reference to the fast-math flags for FPMathOp or FCmp.
903+
FastMathFlagsTy &getFMFsRef() {
904+
return OpType == OperationType::FCmp ? FCmpFlags.FMFs : FMFs;
905+
}
906+
const FastMathFlagsTy &getFMFsRef() const {
907+
return OpType == OperationType::FCmp ? FCmpFlags.FMFs : FMFs;
908+
}
909+
910+
public:
872911
#if !defined(NDEBUG)
873912
/// Returns true if the set flags are valid for \p Opcode.
874913
bool flagsValidForOpcode(unsigned Opcode) const;

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,12 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
453453
GEPFlags &= Other.GEPFlags;
454454
break;
455455
case OperationType::FPMathOp:
456-
FMFs.NoNaNs &= Other.FMFs.NoNaNs;
457-
FMFs.NoInfs &= Other.FMFs.NoInfs;
456+
case OperationType::FCmp:
457+
assert((OpType != OperationType::FCmp ||
458+
FCmpFlags.Pred == Other.FCmpFlags.Pred) &&
459+
"Cannot drop CmpPredicate");
460+
getFMFsRef().NoNaNs &= Other.getFMFsRef().NoNaNs;
461+
getFMFsRef().NoInfs &= Other.getFMFsRef().NoInfs;
458462
break;
459463
case OperationType::NonNegOp:
460464
NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
@@ -469,16 +473,17 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
469473
}
470474

471475
FastMathFlags VPIRFlags::getFastMathFlags() const {
472-
assert(OpType == OperationType::FPMathOp &&
476+
assert((OpType == OperationType::FPMathOp || OpType == OperationType::FCmp) &&
473477
"recipe doesn't have fast math flags");
478+
const FastMathFlagsTy &F = getFMFsRef();
474479
FastMathFlags Res;
475-
Res.setAllowReassoc(FMFs.AllowReassoc);
476-
Res.setNoNaNs(FMFs.NoNaNs);
477-
Res.setNoInfs(FMFs.NoInfs);
478-
Res.setNoSignedZeros(FMFs.NoSignedZeros);
479-
Res.setAllowReciprocal(FMFs.AllowReciprocal);
480-
Res.setAllowContract(FMFs.AllowContract);
481-
Res.setApproxFunc(FMFs.ApproxFunc);
480+
Res.setAllowReassoc(F.AllowReassoc);
481+
Res.setNoNaNs(F.NoNaNs);
482+
Res.setNoInfs(F.NoInfs);
483+
Res.setNoSignedZeros(F.NoSignedZeros);
484+
Res.setAllowReciprocal(F.AllowReciprocal);
485+
Res.setAllowContract(F.AllowContract);
486+
Res.setApproxFunc(F.ApproxFunc);
482487
return Res;
483488
}
484489

@@ -2074,11 +2079,12 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
20742079
Opcode == Instruction::FMul || Opcode == Instruction::FSub ||
20752080
Opcode == Instruction::FNeg || Opcode == Instruction::FDiv ||
20762081
Opcode == Instruction::FRem || Opcode == Instruction::FPExt ||
2077-
Opcode == Instruction::FPTrunc || Opcode == Instruction::FCmp ||
2078-
Opcode == Instruction::Select ||
2082+
Opcode == Instruction::FPTrunc || Opcode == Instruction::Select ||
20792083
Opcode == VPInstruction::WideIVStep ||
20802084
Opcode == VPInstruction::ReductionStartVector ||
20812085
Opcode == VPInstruction::ComputeReductionResult;
2086+
case OperationType::FCmp:
2087+
return Opcode == Instruction::FCmp;
20822088
case OperationType::NonNegOp:
20832089
return Opcode == Instruction::ZExt || Opcode == Instruction::UIToFP;
20842090
case OperationType::Cmp:
@@ -2096,6 +2102,10 @@ void VPIRFlags::printFlags(raw_ostream &O) const {
20962102
case OperationType::Cmp:
20972103
O << " " << CmpInst::getPredicateName(getPredicate());
20982104
break;
2105+
case OperationType::FCmp:
2106+
O << " " << CmpInst::getPredicateName(getPredicate());
2107+
getFastMathFlags().print(O);
2108+
break;
20992109
case OperationType::DisjointOp:
21002110
if (DisjointFlags.IsDisjoint)
21012111
O << " disjoint";
@@ -2204,15 +2214,14 @@ void VPWidenRecipe::execute(VPTransformState &State) {
22042214
Value *B = State.get(getOperand(1));
22052215
Value *C = nullptr;
22062216
if (FCmp) {
2207-
// Propagate fast math flags.
2208-
C = Builder.CreateFCmpFMF(
2209-
getPredicate(), A, B,
2210-
dyn_cast_or_null<Instruction>(getUnderlyingValue()));
2217+
C = Builder.CreateFCmp(getPredicate(), A, B);
22112218
} else {
22122219
C = Builder.CreateICmp(getPredicate(), A, B);
22132220
}
2214-
if (auto *I = dyn_cast<Instruction>(C))
2221+
if (auto *I = dyn_cast<Instruction>(C)) {
2222+
applyFlags(*I);
22152223
applyMetadata(*I);
2224+
}
22162225
State.set(this, C);
22172226
break;
22182227
}

llvm/test/Transforms/LoopVectorize/ARM/mve-icmpcost.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ while.end: ; preds = %while.end.loopexit,
293293
}
294294

295295
; CHECK: LV: Found an estimated cost of 1 for VF 1 For instruction: %cmp1 = fcmp
296-
; CHECK: Cost of 12 for VF 2: WIDEN ir<%cmp1> = fcmp olt ir<%0>, ir<0.000000e+00>
297-
; CHECK: Cost of 24 for VF 4: WIDEN ir<%cmp1> = fcmp olt ir<%0>, ir<0.000000e+00>
296+
; CHECK: Cost of 12 for VF 2: WIDEN ir<%cmp1> = fcmp olt nnan ninf nsz ir<%0>, ir<0.000000e+00>
297+
; CHECK: Cost of 24 for VF 4: WIDEN ir<%cmp1> = fcmp olt nnan ninf nsz ir<%0>, ir<0.000000e+00>
298298
define void @floatcmp(ptr nocapture readonly %pSrc, ptr nocapture %pDst, i32 %blockSize) #0 {
299299
entry:
300300
%cmp.not7 = icmp eq i32 %blockSize, 0

llvm/test/Transforms/LoopVectorize/vplan-printing.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ define void @print_select_with_fastmath_flags(ptr noalias %a, ptr noalias %b, pt
10351035
; CHECK-NEXT: CLONE ir<[[GEP2:%.+]]> = getelementptr inbounds nuw ir<%c>, vp<[[ST]]>
10361036
; CHECK-NEXT: vp<[[PTR2:%.+]]> = vector-pointer ir<[[GEP2]]>
10371037
; CHECK-NEXT: WIDEN ir<[[LD2:%.+]]> = load vp<[[PTR2]]>
1038-
; CHECK-NEXT: WIDEN ir<[[FCMP:%.+]]> = fcmp ogt ir<[[LD1]]>, ir<[[LD2]]>
1038+
; CHECK-NEXT: WIDEN ir<[[FCMP:%.+]]> = fcmp ogt fast ir<[[LD1]]>, ir<[[LD2]]>
10391039
; CHECK-NEXT: WIDEN ir<[[FADD:%.+]]> = fadd fast ir<[[LD1]]>, ir<1.000000e+01>
10401040
; CHECK-NEXT: WIDEN-SELECT ir<[[SELECT:%.+]]> = select fast ir<[[FCMP]]>, ir<[[FADD]]>, ir<[[LD2]]>
10411041
; CHECK-NEXT: CLONE ir<[[GEP3:%.+]]> = getelementptr inbounds nuw ir<%a>, vp<[[ST]]>

0 commit comments

Comments
 (0)