Skip to content

Commit a6f0221

Browse files
committed
[SLP] fix fast-math-flag propagation on FP reductions
As shown in the test diffs, we could miscompile by propagating fast-math flags that did not exist in the original code. The handling of the flags required for fmin/fmax reductions will be fixed in a follow-up patch.
1 parent 39e1e53 commit a6f0221

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6820,12 +6820,18 @@ class HorizontalReduction {
68206820
if (NumReducedVals < 4)
68216821
return false;
68226822

6823-
// FIXME: Fast-math-flags should be set based on the instructions in the
6824-
// reduction (not all of 'fast' are required).
6823+
// Intersect the fast-math-flags from all reduction operations.
6824+
FastMathFlags RdxFMF;
6825+
RdxFMF.set();
6826+
for (ReductionOpsType &RdxOp : ReductionOps) {
6827+
for (Value *RdxVal : RdxOp) {
6828+
if (auto *FPMO = dyn_cast<FPMathOperator>(RdxVal))
6829+
RdxFMF &= FPMO->getFastMathFlags();
6830+
}
6831+
}
6832+
68256833
IRBuilder<> Builder(cast<Instruction>(ReductionRoot));
6826-
FastMathFlags Unsafe;
6827-
Unsafe.setFast();
6828-
Builder.setFastMathFlags(Unsafe);
6834+
Builder.setFastMathFlags(RdxFMF);
68296835

68306836
BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
68316837
// The same extra argument may be used several times, so log each attempt
@@ -7071,9 +7077,6 @@ class HorizontalReduction {
70717077
assert(isPowerOf2_32(ReduxWidth) &&
70727078
"We only handle power-of-two reductions for now");
70737079

7074-
// FIXME: The builder should use an FMF guard. It should not be hard-coded
7075-
// to 'fast'.
7076-
assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
70777080
return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
70787081
ReductionOps.back());
70797082
}

llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,7 +1766,6 @@ bb.1:
17661766
ret void
17671767
}
17681768

1769-
; FIXME: This is a miscompile.
17701769
; The FMF on the reduction should match the incoming insts.
17711770

17721771
define float @fadd_v4f32_fmf(float* %p) {
@@ -1776,7 +1775,7 @@ define float @fadd_v4f32_fmf(float* %p) {
17761775
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
17771776
; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
17781777
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
1779-
; CHECK-NEXT: [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
1778+
; CHECK-NEXT: [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
17801779
; CHECK-NEXT: ret float [[TMP3]]
17811780
;
17821781
; STORE-LABEL: @fadd_v4f32_fmf(
@@ -1785,7 +1784,7 @@ define float @fadd_v4f32_fmf(float* %p) {
17851784
; STORE-NEXT: [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
17861785
; STORE-NEXT: [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
17871786
; STORE-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
1788-
; STORE-NEXT: [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
1787+
; STORE-NEXT: [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
17891788
; STORE-NEXT: ret float [[TMP3]]
17901789
;
17911790
%p1 = getelementptr inbounds float, float* %p, i64 1
@@ -1801,14 +1800,18 @@ define float @fadd_v4f32_fmf(float* %p) {
18011800
ret float %add3
18021801
}
18031802

1803+
; The minimal FMF required for an fadd reduction are "reassoc nsz".
1804+
; Only the common FMF of all operations in the reduction propagate to the result.
1805+
; In this example, "contract nnan arcp" are dropped, but "ninf" transfers with the required flags.
1806+
18041807
define float @fadd_v4f32_fmf_intersect(float* %p) {
18051808
; CHECK-LABEL: @fadd_v4f32_fmf_intersect(
18061809
; CHECK-NEXT: [[P1:%.*]] = getelementptr inbounds float, float* [[P:%.*]], i64 1
18071810
; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds float, float* [[P]], i64 2
18081811
; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
18091812
; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
18101813
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
1811-
; CHECK-NEXT: [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
1814+
; CHECK-NEXT: [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
18121815
; CHECK-NEXT: ret float [[TMP3]]
18131816
;
18141817
; STORE-LABEL: @fadd_v4f32_fmf_intersect(
@@ -1817,7 +1820,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
18171820
; STORE-NEXT: [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
18181821
; STORE-NEXT: [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
18191822
; STORE-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
1820-
; STORE-NEXT: [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
1823+
; STORE-NEXT: [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
18211824
; STORE-NEXT: ret float [[TMP3]]
18221825
;
18231826
%p1 = getelementptr inbounds float, float* %p, i64 1

0 commit comments

Comments
 (0)