-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[LoopUtils] Consider fast-math flags derived from function attribute when getting recurrence identity #162630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…unction attributes. nfc
…when getting recurrence identity
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Mel Chen (Mel-Chen) ChangesBased on #162628 Full diff: https://github.com/llvm/llvm-project/pull/162630.diff 5 Files Affected:
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index d3497716ca844..21e9d5ea9d066 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -470,6 +470,9 @@ class LLVM_ABI Function : public GlobalObject, public ilist_node<Function> {
return AttributeSets.getFnStackAlignment();
}
+ /// Derive fast-math flags from the function attributes.
+ FastMathFlags getFMFFromFnAttribute() const;
+
/// Returns true if the function has ssp, sspstrong, or sspreq fn attrs.
bool hasStackProtectorFnAttr() const;
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index b8c540ce4b99d..9b842d6ab37c5 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -998,11 +998,7 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
ScalarEvolution *SE) {
BasicBlock *Header = TheLoop->getHeader();
Function &F = *Header->getParent();
- FastMathFlags FMF;
- FMF.setNoNaNs(
- F.getFnAttribute("no-nans-fp-math").getValueAsBool());
- FMF.setNoSignedZeros(
- F.getFnAttribute("no-signed-zeros-fp-math").getValueAsBool());
+ FastMathFlags FMF = F.getFMFFromFnAttribute();
if (AddReductionVar(Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index fc067459dcba3..d2b41c1c9c2ab 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -784,6 +784,14 @@ uint64_t Function::getFnAttributeAsParsedInteger(StringRef Name,
return Result;
}
+FastMathFlags Function::getFMFFromFnAttribute() const {
+ FastMathFlags FuncFMF;
+ FuncFMF.setNoNaNs(getFnAttribute("no-nans-fp-math").getValueAsBool());
+ FuncFMF.setNoSignedZeros(
+ getFnAttribute("no-signed-zeros-fp-math").getValueAsBool());
+ return FuncFMF;
+}
+
/// gets the specified attribute from the list of attributes.
Attribute Function::getParamAttribute(unsigned ArgNo,
Attribute::AttrKind Kind) const {
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index b6ba82288aeb4..c001de4f1d2b5 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1398,8 +1398,10 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
RecurKind RdxKind) {
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
auto getIdentity = [&]() {
+ Function *Fn = Builder.GetInsertBlock()->getParent();
return getRecurrenceIdentity(RdxKind, SrcVecEltTy,
- Builder.getFastMathFlags());
+ Builder.getFastMathFlags() |
+ Fn->getFMFFromFnAttribute());
};
switch (RdxKind) {
case RecurKind::AddChainWithSubs:
@@ -1442,7 +1444,9 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
assert(VPReductionIntrinsic::isVPReduction(VPID) &&
"No VPIntrinsic for this reduction");
auto *EltTy = cast<VectorType>(Src->getType())->getElementType();
- Value *Iden = getRecurrenceIdentity(Kind, EltTy, Builder.getFastMathFlags());
+ Function *Fn = Builder.GetInsertBlock()->getParent();
+ Value *Iden = getRecurrenceIdentity(
+ Kind, EltTy, Builder.getFastMathFlags() | Fn->getFMFFromFnAttribute());
Value *Ops[] = {Iden, Src, Mask, EVL};
return Builder.CreateIntrinsic(EltTy, VPID, Ops);
}
diff --git a/llvm/test/Transforms/LoopVectorize/reduction.ll b/llvm/test/Transforms/LoopVectorize/reduction.ll
index 65d57015b0140..dffcca89f9813 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction.ll
@@ -1335,3 +1335,62 @@ exit:
%lcssa.exit = phi i64 [ %phi.sum.next, %loop.latch ]
ret i64 %lcssa.exit
}
+
+define float @reduction_sum_float(i64 %n, ptr %a) #0 {
+; CHECK-LABEL: define float @reduction_sum_float(
+; CHECK-SAME: i64 [[N:%.*]], ptr [[A:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[N]], 4
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: [[N_VEC:%.*]] = and i64 [[N]], -4
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x float> [ <float 0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, [[VECTOR_PH]] ], [ [[TMP1:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP0]], align 4
+; CHECK-NEXT: [[TMP1]] = fadd <4 x float> [[VEC_PHI]], [[WIDE_LOAD]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP2]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP27:![0-9]+]]
+; CHECK: middle.block:
+; CHECK-NEXT: [[TMP3:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP1]])
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[N]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK: scalar.ph:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi float [ [[TMP3]], [[MIDDLE_BLOCK]] ], [ 0.000000e+00, [[ENTRY]] ]
+; CHECK-NEXT: br label [[LOOP:%.*]]
+; CHECK: loop:
+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
+; CHECK-NEXT: [[RDX:%.*]] = phi float [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[SUM:%.*]], [[LOOP]] ]
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[IV]]
+; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT: [[SUM]] = fadd float [[RDX]], [[TMP4]]
+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i64 [[IV_NEXT]], [[N]]
+; CHECK-NEXT: br i1 [[EXITCOND]], label [[EXIT]], label [[LOOP]], !llvm.loop [[LOOP28:![0-9]+]]
+; CHECK: exit:
+; CHECK-NEXT: [[SUM_LCSSA:%.*]] = phi float [ [[SUM]], [[LOOP]] ], [ [[TMP3]], [[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: ret float [[SUM_LCSSA]]
+;
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+ %rdx = phi float [ 0.0, %entry ], [ %sum, %loop ]
+ %arrayidx = getelementptr inbounds float, ptr %a, i64 %iv
+ %0 = load float, ptr %arrayidx, align 4
+ %sum = fadd float %rdx, %0
+ %iv.next = add nuw nsw i64 %iv, 1
+ %exitcond = icmp eq i64 %iv.next, %n
+ br i1 %exitcond, label %exit, label %loop
+
+exit:
+ %sum.lcssa = phi float [ %sum, %loop ]
+ ret float %sum.lcssa
+}
+
+attributes #0 = { "no-signed-zeros-fp-math"="true" }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No new uses of the fast math function attributes should be introduced. They are actively being removed
I did this because I hit an assertion when obtaining the identity in the downstream toolchain.
Since upstream doesn’t yet call getRecurrenceIdentity for RecurKind::FMin/RecurKind::FMax, this cannot be reproduced in upstream. We can only use fadd as a simple example. |
FastMathFlags FMF; | ||
FMF.setNoNaNs( | ||
F.getFnAttribute("no-nans-fp-math").getValueAsBool()); | ||
FMF.setNoSignedZeros( | ||
F.getFnAttribute("no-signed-zeros-fp-math").getValueAsBool()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be deleted without replacement
Based on #162628