Skip to content

Commit 812e02a

Browse files
authored
[NVPTX] Use fast-math flags when lowering sin, cos, frem (#133121)
Update the lowering rules for sin, cos, and frem to respect the instruction-level flags in addition to the global and function-level options. For sin and cos, the TableGen lowering has been updated to check the `afn` flag on the node. The lowering for frem has been pulled to custom instruction legalization in order to allow for DAG Combiner optimizations to operate over the expanded instructions.
1 parent c0952a9 commit 812e02a

File tree

6 files changed

+361
-113
lines changed

6 files changed

+361
-113
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "NVPTXTargetMachine.h"
1919
#include "NVPTXTargetObjectFile.h"
2020
#include "NVPTXUtilities.h"
21+
#include "llvm/ADT/APFloat.h"
2122
#include "llvm/ADT/APInt.h"
2223
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/ADT/SmallVector.h"
@@ -932,6 +933,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
932933
setOperationAction(Op, MVT::bf16, Promote);
933934
AddPromotedToType(Op, MVT::bf16, MVT::f32);
934935
}
936+
setOperationAction(ISD::FREM, {MVT::f32, MVT::f64}, Custom);
935937

936938
setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
937939
if (STI.getPTXVersion() >= 65) {
@@ -2819,6 +2821,34 @@ static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
28192821
SDLoc(Op), Opcode, DAG);
28202822
}
28212823

2824+
static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
2825+
bool AllowUnsafeFPMath) {
2826+
// Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
2827+
// i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
2828+
// the semantics of LLVM's frem.
2829+
SDLoc DL(Op);
2830+
SDValue X = Op->getOperand(0);
2831+
SDValue Y = Op->getOperand(1);
2832+
EVT Ty = Op.getValueType();
2833+
2834+
SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y);
2835+
SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div);
2836+
SDValue Mul =
2837+
DAG.getNode(ISD::FMUL, DL, Ty, Trunc, Y, SDNodeFlags::AllowContract);
2838+
SDValue Sub =
2839+
DAG.getNode(ISD::FSUB, DL, Ty, X, Mul, SDNodeFlags::AllowContract);
2840+
2841+
if (AllowUnsafeFPMath || Op->getFlags().hasNoInfs())
2842+
return Sub;
2843+
2844+
// If Y is infinite, return X
2845+
SDValue AbsY = DAG.getNode(ISD::FABS, DL, Ty, Y);
2846+
SDValue Inf =
2847+
DAG.getConstantFP(APFloat::getInf(Ty.getFltSemantics()), DL, Ty);
2848+
SDValue IsInf = DAG.getSetCC(DL, MVT::i1, AbsY, Inf, ISD::SETEQ);
2849+
return DAG.getSelect(DL, Ty, IsInf, X, Sub);
2850+
}
2851+
28222852
SDValue
28232853
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28242854
switch (Op.getOpcode()) {
@@ -2913,6 +2943,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
29132943
case ISD::CTPOP:
29142944
case ISD::CTLZ:
29152945
return lowerCTLZCTPOP(Op, DAG);
2946+
case ISD::FREM:
2947+
return lowerFREM(Op, DAG, allowUnsafeFPMath(DAG.getMachineFunction()));
29162948

29172949
default:
29182950
llvm_unreachable("Custom lowering not defined for operation");

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 21 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
150150

151151
def doMulWide : Predicate<"doMulWide">;
152152

153-
def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
154-
def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">;
155-
156153
def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
157154
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
158155

@@ -211,6 +208,12 @@ class ValueToRegClass<ValueType T> {
211208
// Some Common Instruction Class Templates
212209
//===----------------------------------------------------------------------===//
213210

211+
class OneUse1<SDPatternOperator operator>
212+
: PatFrag<(ops node:$A), (operator node:$A), [{ return N->hasOneUse(); }]>;
213+
214+
class fpimm_pos_inf<ValueType vt>
215+
: FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
216+
214217
// Utility class to wrap up information about a register and DAG type for more
215218
// convenient iteration and parameterization
216219
class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
@@ -442,7 +445,7 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
442445
class BinOpAllowsFMA<SDPatternOperator operator>
443446
: PatFrag<(ops node:$A, node:$B),
444447
(operator node:$A, node:$B), [{
445-
return allowFMA() || N->getFlags().hasAllowContract();;
448+
return allowFMA() || N->getFlags().hasAllowContract();
446449
}]>;
447450

448451
multiclass F3_fma_component<string op_str, SDNode op_node> {
@@ -693,10 +696,7 @@ let hasSideEffects = false in {
693696
defm CVT_to_tf32_rz_relu_satf : CVT_TO_TF32<"rz.relu.satfinite", [hasPTX<86>, hasSM<100>]>;
694697
}
695698

696-
def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
697-
return N->hasOneUse();
698-
}]>;
699-
699+
def fpround_oneuse : OneUse1<fpround>;
700700
def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse f32:$lo)),
701701
(bf16 (fpround_oneuse f32:$hi)))),
702702
(CVT_bf16x2_f32 $hi, $lo, CvtRN)>,
@@ -786,18 +786,14 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
786786
// Test Instructions
787787
//-----------------------------------
788788

789+
def fabs_oneuse : OneUse1<fabs>;
790+
789791
def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
790792
"testp.infinite.f32 \t$p, $a;",
791-
[]>;
792-
def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
793-
"testp.infinite.f32 \t$p, $a;",
794-
[]>;
793+
[(set i1:$p, (seteq (fabs_oneuse f32:$a), fpimm_pos_inf<f32>))]>;
795794
def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
796795
"testp.infinite.f64 \t$p, $a;",
797-
[]>;
798-
def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
799-
"testp.infinite.f64 \t$p, $a;",
800-
[]>;
796+
[(set i1:$p, (seteq (fabs_oneuse f64:$a), fpimm_pos_inf<f64>))]>;
801797

802798
//-----------------------------------
803799
// Integer Arithmetic
@@ -1362,99 +1358,19 @@ defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
13621358
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
13631359

13641360
// sin/cos
1361+
1362+
class UnaryOpAllowsApproxFn<SDPatternOperator operator>
1363+
: PatFrag<(ops node:$A),
1364+
(operator node:$A), [{
1365+
return allowUnsafeFPMath() || N->getFlags().hasApproximateFuncs();
1366+
}]>;
1367+
13651368
def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
13661369
"sin.approx.f32 \t$dst, $src;",
1367-
[(set f32:$dst, (fsin f32:$src))]>,
1368-
Requires<[allowUnsafeFPMath]>;
1370+
[(set f32:$dst, (UnaryOpAllowsApproxFn<fsin> f32:$src))]>;
13691371
def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
13701372
"cos.approx.f32 \t$dst, $src;",
1371-
[(set f32:$dst, (fcos f32:$src))]>,
1372-
Requires<[allowUnsafeFPMath]>;
1373-
1374-
// Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
1375-
// i.e. "poor man's fmod()". When y is infinite, x is returned. This matches the
1376-
// semantics of LLVM's frem.
1377-
1378-
// frem - f32 FTZ
1379-
def : Pat<(frem f32:$x, f32:$y),
1380-
(FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
1381-
(FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
1382-
$y))>,
1383-
Requires<[doF32FTZ, allowUnsafeFPMath]>;
1384-
def : Pat<(frem f32:$x, fpimm:$y),
1385-
(FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
1386-
(FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
1387-
fpimm:$y))>,
1388-
Requires<[doF32FTZ, allowUnsafeFPMath]>;
1389-
1390-
def : Pat<(frem f32:$x, f32:$y),
1391-
(SELP_f32rr $x,
1392-
(FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
1393-
(FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
1394-
$y)),
1395-
(TESTINF_f32r $y))>,
1396-
Requires<[doF32FTZ, noUnsafeFPMath]>;
1397-
def : Pat<(frem f32:$x, fpimm:$y),
1398-
(SELP_f32rr $x,
1399-
(FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
1400-
(FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
1401-
fpimm:$y)),
1402-
(TESTINF_f32i fpimm:$y))>,
1403-
Requires<[doF32FTZ, noUnsafeFPMath]>;
1404-
1405-
// frem - f32
1406-
def : Pat<(frem f32:$x, f32:$y),
1407-
(FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
1408-
(FDIV32rr_prec $x, $y), CvtRZI),
1409-
$y))>,
1410-
Requires<[allowUnsafeFPMath]>;
1411-
def : Pat<(frem f32:$x, fpimm:$y),
1412-
(FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
1413-
(FDIV32ri_prec $x, fpimm:$y), CvtRZI),
1414-
fpimm:$y))>,
1415-
Requires<[allowUnsafeFPMath]>;
1416-
1417-
def : Pat<(frem f32:$x, f32:$y),
1418-
(SELP_f32rr $x,
1419-
(FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
1420-
(FDIV32rr_prec $x, $y), CvtRZI),
1421-
$y)),
1422-
(TESTINF_f32r Float32Regs:$y))>,
1423-
Requires<[noUnsafeFPMath]>;
1424-
def : Pat<(frem f32:$x, fpimm:$y),
1425-
(SELP_f32rr $x,
1426-
(FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
1427-
(FDIV32ri_prec $x, fpimm:$y), CvtRZI),
1428-
fpimm:$y)),
1429-
(TESTINF_f32i fpimm:$y))>,
1430-
Requires<[noUnsafeFPMath]>;
1431-
1432-
// frem - f64
1433-
def : Pat<(frem f64:$x, f64:$y),
1434-
(FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
1435-
(FDIV64rr $x, $y), CvtRZI),
1436-
$y))>,
1437-
Requires<[allowUnsafeFPMath]>;
1438-
def : Pat<(frem f64:$x, fpimm:$y),
1439-
(FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
1440-
(FDIV64ri $x, fpimm:$y), CvtRZI),
1441-
fpimm:$y))>,
1442-
Requires<[allowUnsafeFPMath]>;
1443-
1444-
def : Pat<(frem f64:$x, f64:$y),
1445-
(SELP_f64rr $x,
1446-
(FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
1447-
(FDIV64rr $x, $y), CvtRZI),
1448-
$y)),
1449-
(TESTINF_f64r Float64Regs:$y))>,
1450-
Requires<[noUnsafeFPMath]>;
1451-
def : Pat<(frem f64:$x, fpimm:$y),
1452-
(SELP_f64rr $x,
1453-
(FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
1454-
(FDIV64ri $x, fpimm:$y), CvtRZI),
1455-
fpimm:$y)),
1456-
(TESTINF_f64r $y))>,
1457-
Requires<[noUnsafeFPMath]>;
1373+
[(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
14581374

14591375
//-----------------------------------
14601376
// Bitwise operations

llvm/test/CodeGen/NVPTX/f16-instructions.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,14 @@ define half @test_fdiv(half %a, half %b) #0 {
200200
; CHECK-NOFTZ-DAG: cvt.f32.f16 [[FB:%f[0-9]+]], [[B]];
201201
; CHECK-NOFTZ-NEXT: div.rn.f32 [[D:%f[0-9]+]], [[FA]], [[FB]];
202202
; CHECK-NOFTZ-NEXT: cvt.rzi.f32.f32 [[DI:%f[0-9]+]], [[D]];
203-
; CHECK-NOFTZ-NEXT: mul.f32 [[RI:%f[0-9]+]], [[DI]], [[FB]];
204-
; CHECK-NOFTZ-NEXT: sub.f32 [[RF:%f[0-9]+]], [[FA]], [[RI]];
203+
; CHECK-NOFTZ-NEXT: neg.f32 [[DNEG:%f[0-9]+]], [[DI]];
204+
; CHECK-NOFTZ-NEXT: fma.rn.f32 [[RF:%f[0-9]+]], [[DNEG]], [[FB]], [[FA]];
205205
; CHECK-F16-FTZ-DAG: cvt.ftz.f32.f16 [[FA:%f[0-9]+]], [[A]];
206206
; CHECK-F16-FTZ-DAG: cvt.ftz.f32.f16 [[FB:%f[0-9]+]], [[B]];
207207
; CHECK-F16-FTZ-NEXT: div.rn.ftz.f32 [[D:%f[0-9]+]], [[FA]], [[FB]];
208208
; CHECK-F16-FTZ-NEXT: cvt.rzi.ftz.f32.f32 [[DI:%f[0-9]+]], [[D]];
209-
; CHECK-F16-FTZ-NEXT: mul.ftz.f32 [[RI:%f[0-9]+]], [[DI]], [[FB]];
210-
; CHECK-F16-FTZ-NEXT: sub.ftz.f32 [[RF:%f[0-9]+]], [[FA]], [[RI]];
209+
; CHECK-F16-FTZ-NEXT: neg.ftz.f32 [[DNEG:%f[0-9]+]], [[DI]];
210+
; CHECK-F16-FTZ-NEXT: fma.rn.ftz.f32 [[RF:%f[0-9]+]], [[DNEG]], [[FB]], [[FA]];
211211
; CHECK-NEXT: testp.infinite.f32 [[ISBINF:%p[0-9]+]], [[FB]];
212212
; CHECK-NEXT: selp.f32 [[RESULT:%f[0-9]+]], [[FA]], [[RF]], [[ISBINF]];
213213
; CHECK-NEXT: cvt.rn.f16.f32 [[R:%rs[0-9]+]], [[RESULT]];

llvm/test/CodeGen/NVPTX/f16x2-instructions.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,17 +362,17 @@ define <2 x half> @test_frem(<2 x half> %a, <2 x half> %b) #0 {
362362
; CHECK-NEXT: cvt.f32.f16 %f2, %rs4;
363363
; CHECK-NEXT: div.rn.f32 %f3, %f2, %f1;
364364
; CHECK-NEXT: cvt.rzi.f32.f32 %f4, %f3;
365-
; CHECK-NEXT: mul.f32 %f5, %f4, %f1;
366-
; CHECK-NEXT: sub.f32 %f6, %f2, %f5;
365+
; CHECK-NEXT: neg.f32 %f5, %f4;
366+
; CHECK-NEXT: fma.rn.f32 %f6, %f5, %f1, %f2;
367367
; CHECK-NEXT: testp.infinite.f32 %p1, %f1;
368368
; CHECK-NEXT: selp.f32 %f7, %f2, %f6, %p1;
369369
; CHECK-NEXT: cvt.rn.f16.f32 %rs5, %f7;
370370
; CHECK-NEXT: cvt.f32.f16 %f8, %rs1;
371371
; CHECK-NEXT: cvt.f32.f16 %f9, %rs3;
372372
; CHECK-NEXT: div.rn.f32 %f10, %f9, %f8;
373373
; CHECK-NEXT: cvt.rzi.f32.f32 %f11, %f10;
374-
; CHECK-NEXT: mul.f32 %f12, %f11, %f8;
375-
; CHECK-NEXT: sub.f32 %f13, %f9, %f12;
374+
; CHECK-NEXT: neg.f32 %f12, %f11;
375+
; CHECK-NEXT: fma.rn.f32 %f13, %f12, %f8, %f9;
376376
; CHECK-NEXT: testp.infinite.f32 %p2, %f8;
377377
; CHECK-NEXT: selp.f32 %f14, %f9, %f13, %p2;
378378
; CHECK-NEXT: cvt.rn.f16.f32 %rs6, %f14;

llvm/test/CodeGen/NVPTX/fast-math.ll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,20 @@ define float @fadd_ftz(float %a, float %b) #1 {
131131
declare float @llvm.sin.f32(float)
132132
declare float @llvm.cos.f32(float)
133133

134+
; CHECK-LABEL: fsin_approx_afn
135+
; CHECK: sin.approx.f32
136+
define float @fsin_approx_afn(float %a) {
137+
%r = tail call afn float @llvm.sin.f32(float %a)
138+
ret float %r
139+
}
140+
141+
; CHECK-LABEL: fcos_approx_afn
142+
; CHECK: cos.approx.f32
143+
define float @fcos_approx_afn(float %a) {
144+
%r = tail call afn float @llvm.cos.f32(float %a)
145+
ret float %r
146+
}
147+
134148
; CHECK-LABEL: fsin_approx
135149
; CHECK: sin.approx.f32
136150
define float @fsin_approx(float %a) #0 {

0 commit comments

Comments
 (0)