Skip to content

Commit 4c6fd63

Browse files
[DAGCombiner] Eliminate fp casts if we have the right fast math flags
When floating-point operations are legalized to operations of a higher precision (e.g. f16 fadd being legalized to f32 fadd) then we get narrowing then widening operations between each operation. With the appropriate fast math flags (nnan ninf contract) we can eliminate these casts.
1 parent 8d333e1 commit 4c6fd63

File tree

15 files changed

+519
-462
lines changed

15 files changed

+519
-462
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18455,7 +18455,45 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
1845518455
return SDValue();
1845618456
}
1845718457

18458+
// Eliminate a floating-point widening of a narrowed value if the fast math
18459+
// flags allow it.
18460+
static SDValue eliminateFPCastPair(SDNode *N) {
18461+
SDValue N0 = N->getOperand(0);
18462+
EVT VT = N->getValueType(0);
18463+
18464+
unsigned NarrowingOp;
18465+
switch (N->getOpcode()) {
18466+
case ISD::FP16_TO_FP:
18467+
NarrowingOp = ISD::FP_TO_FP16;
18468+
break;
18469+
case ISD::BF16_TO_FP:
18470+
NarrowingOp = ISD::FP_TO_BF16;
18471+
break;
18472+
case ISD::FP_EXTEND:
18473+
NarrowingOp = ISD::FP_ROUND;
18474+
break;
18475+
default:
18476+
llvm_unreachable("Expected widening FP cast");
18477+
}
18478+
18479+
if (N0.getOpcode() == NarrowingOp && N0.getOperand(0).getValueType() == VT) {
18480+
const SDNodeFlags SrcFlags = N0->getFlags();
18481+
const SDNodeFlags DstFlags = N->getFlags();
18482+
// Narrowing can introduce inf and change the encoding of a nan, so the
18483+
// destination must have the nnan and ninf flags to indicate that we don't
18484+
// need to care about that. We are also removing a rounding step, and that
18485+
// requires both the source and destination to allow contraction.
18486+
if (DstFlags.hasNoNaNs() && DstFlags.hasNoInfs() &&
18487+
SrcFlags.hasAllowContract() && DstFlags.hasAllowContract()) {
18488+
return N0.getOperand(0);
18489+
}
18490+
}
18491+
18492+
return SDValue();
18493+
}
18494+
1845818495
SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
18496+
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
1845918497
SDValue N0 = N->getOperand(0);
1846018498
EVT VT = N->getValueType(0);
1846118499
SDLoc DL(N);
@@ -18507,6 +18545,9 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
1850718545
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
1850818546
return NewVSel;
1850918547

18548+
if (SDValue CastEliminated = eliminateFPCastPair(N))
18549+
return CastEliminated;
18550+
1851018551
return SDValue();
1851118552
}
1851218553

@@ -27209,6 +27250,7 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
2720927250
}
2721027251

2721127252
SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
27253+
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
2721227254
auto Op = N->getOpcode();
2721327255
assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
2721427256
"opcode should be FP16_TO_FP or BF16_TO_FP.");
@@ -27223,6 +27265,9 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
2722327265
}
2722427266
}
2722527267

27268+
if (SDValue CastEliminated = eliminateFPCastPair(N))
27269+
return CastEliminated;
27270+
2722627271
// Sometimes constants manage to survive very late in the pipeline, e.g.,
2722727272
// because they are wrapped inside the <1 x f16> type. Try one last time to
2722827273
// get rid of them.

llvm/test/CodeGen/AArch64/f16-instructions.ll

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,8 @@ define half @test_fmadd(half %a, half %b, half %c) #0 {
8484
; CHECK-CVT-SD: // %bb.0:
8585
; CHECK-CVT-SD-NEXT: fcvt s1, h1
8686
; CHECK-CVT-SD-NEXT: fcvt s0, h0
87-
; CHECK-CVT-SD-NEXT: fmul s0, s0, s1
88-
; CHECK-CVT-SD-NEXT: fcvt s1, h2
89-
; CHECK-CVT-SD-NEXT: fcvt h0, s0
90-
; CHECK-CVT-SD-NEXT: fcvt s0, h0
91-
; CHECK-CVT-SD-NEXT: fadd s0, s0, s1
87+
; CHECK-CVT-SD-NEXT: fcvt s2, h2
88+
; CHECK-CVT-SD-NEXT: fmadd s0, s0, s1, s2
9289
; CHECK-CVT-SD-NEXT: fcvt h0, s0
9390
; CHECK-CVT-SD-NEXT: ret
9491
;
@@ -1248,6 +1245,15 @@ define half @test_atan(half %a) #0 {
12481245
}
12491246

12501247
define half @test_atan2(half %a, half %b) #0 {
1248+
; CHECK-LABEL: test_atan2:
1249+
; CHECK: // %bb.0:
1250+
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
1251+
; CHECK-NEXT: fcvt s0, h0
1252+
; CHECK-NEXT: fcvt s1, h1
1253+
; CHECK-NEXT: bl atan2f
1254+
; CHECK-NEXT: fcvt h0, s0
1255+
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
1256+
; CHECK-NEXT: ret
12511257
%r = call half @llvm.atan2.f16(half %a, half %b)
12521258
ret half %r
12531259
}

llvm/test/CodeGen/AArch64/fmla.ll

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,11 +1114,8 @@ define half @fmul_f16(half %a, half %b, half %c) {
11141114
; CHECK-SD-NOFP16: // %bb.0: // %entry
11151115
; CHECK-SD-NOFP16-NEXT: fcvt s1, h1
11161116
; CHECK-SD-NOFP16-NEXT: fcvt s0, h0
1117-
; CHECK-SD-NOFP16-NEXT: fmul s0, s0, s1
1118-
; CHECK-SD-NOFP16-NEXT: fcvt s1, h2
1119-
; CHECK-SD-NOFP16-NEXT: fcvt h0, s0
1120-
; CHECK-SD-NOFP16-NEXT: fcvt s0, h0
1121-
; CHECK-SD-NOFP16-NEXT: fadd s0, s0, s1
1117+
; CHECK-SD-NOFP16-NEXT: fcvt s2, h2
1118+
; CHECK-SD-NOFP16-NEXT: fmadd s0, s0, s1, s2
11221119
; CHECK-SD-NOFP16-NEXT: fcvt h0, s0
11231120
; CHECK-SD-NOFP16-NEXT: ret
11241121
;

llvm/test/CodeGen/AArch64/fp16_fast_math.ll

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,112 @@ entry:
8888
%add = fadd ninf half %x, %y
8989
ret half %add
9090
}
91+
92+
; Check that when we have the right fast math flags the converts in between the
93+
; two fadds are removed.
94+
95+
define half @normal_fadd_sequence(half %x, half %y, half %z) {
96+
; CHECK-CVT-LABEL: name: normal_fadd_sequence
97+
; CHECK-CVT: bb.0.entry:
98+
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
99+
; CHECK-CVT-NEXT: {{ $}}
100+
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
101+
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
102+
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
103+
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
104+
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
105+
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
106+
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = nofpexcept FCVTHSr killed [[FADDSrr]], implicit $fpcr
107+
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr killed [[FCVTHSr]], implicit $fpcr
108+
; CHECK-CVT-NEXT: [[FCVTSHr3:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY]], implicit $fpcr
109+
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = nofpexcept FADDSrr killed [[FCVTSHr2]], killed [[FCVTSHr3]], implicit $fpcr
110+
; CHECK-CVT-NEXT: [[FCVTHSr1:%[0-9]+]]:fpr16 = nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
111+
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr1]]
112+
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
113+
;
114+
; CHECK-FP16-LABEL: name: normal_fadd_sequence
115+
; CHECK-FP16: bb.0.entry:
116+
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
117+
; CHECK-FP16-NEXT: {{ $}}
118+
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
119+
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
120+
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
121+
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
122+
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
123+
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
124+
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
125+
entry:
126+
%add1 = fadd half %x, %y
127+
%add2 = fadd half %add1, %z
128+
ret half %add2
129+
}
130+
131+
define half @nnan_ninf_contract_fadd_sequence(half %x, half %y, half %z) {
132+
; CHECK-CVT-LABEL: name: nnan_ninf_contract_fadd_sequence
133+
; CHECK-CVT: bb.0.entry:
134+
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
135+
; CHECK-CVT-NEXT: {{ $}}
136+
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
137+
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
138+
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
139+
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
140+
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
141+
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
142+
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY]], implicit $fpcr
143+
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FADDSrr killed [[FADDSrr]], killed [[FCVTSHr2]], implicit $fpcr
144+
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
145+
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr]]
146+
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
147+
;
148+
; CHECK-FP16-LABEL: name: nnan_ninf_contract_fadd_sequence
149+
; CHECK-FP16: bb.0.entry:
150+
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
151+
; CHECK-FP16-NEXT: {{ $}}
152+
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
153+
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
154+
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
155+
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
156+
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
157+
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
158+
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
159+
entry:
160+
%add1 = fadd nnan ninf contract half %x, %y
161+
%add2 = fadd nnan ninf contract half %add1, %z
162+
ret half %add2
163+
}
164+
165+
define half @ninf_fadd_sequence(half %x, half %y, half %z) {
166+
; CHECK-CVT-LABEL: name: ninf_fadd_sequence
167+
; CHECK-CVT: bb.0.entry:
168+
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
169+
; CHECK-CVT-NEXT: {{ $}}
170+
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
171+
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
172+
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
173+
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
174+
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
175+
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = ninf nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
176+
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = ninf nofpexcept FCVTHSr killed [[FADDSrr]], implicit $fpcr
177+
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr killed [[FCVTHSr]], implicit $fpcr
178+
; CHECK-CVT-NEXT: [[FCVTSHr3:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY]], implicit $fpcr
179+
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = ninf nofpexcept FADDSrr killed [[FCVTSHr2]], killed [[FCVTSHr3]], implicit $fpcr
180+
; CHECK-CVT-NEXT: [[FCVTHSr1:%[0-9]+]]:fpr16 = ninf nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
181+
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr1]]
182+
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
183+
;
184+
; CHECK-FP16-LABEL: name: ninf_fadd_sequence
185+
; CHECK-FP16: bb.0.entry:
186+
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
187+
; CHECK-FP16-NEXT: {{ $}}
188+
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
189+
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
190+
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
191+
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = ninf nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
192+
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = ninf nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
193+
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
194+
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
195+
entry:
196+
%add1 = fadd ninf half %x, %y
197+
%add2 = fadd ninf half %add1, %z
198+
ret half %add2
199+
}

0 commit comments

Comments
 (0)