@@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
150150
151151def doMulWide : Predicate<"doMulWide">;
152152
153- def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
154- def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">;
155-
156153def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
157154def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
158155
@@ -211,6 +208,12 @@ class ValueToRegClass<ValueType T> {
211208// Some Common Instruction Class Templates
212209//===----------------------------------------------------------------------===//
213210
211+ class OneUse1<SDPatternOperator operator>
212+ : PatFrag<(ops node:$A), (operator node:$A), [{ return N->hasOneUse(); }]>;
213+
214+ class fpimm_pos_inf<ValueType vt>
215+ : FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
216+
214217// Utility class to wrap up information about a register and DAG type for more
215218// convenient iteration and parameterization
216219class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
@@ -442,7 +445,7 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
442445class BinOpAllowsFMA<SDPatternOperator operator>
443446 : PatFrag<(ops node:$A, node:$B),
444447 (operator node:$A, node:$B), [{
445- return allowFMA() || N->getFlags().hasAllowContract();;
448+ return allowFMA() || N->getFlags().hasAllowContract();
446449}]>;
447450
448451multiclass F3_fma_component<string op_str, SDNode op_node> {
@@ -693,10 +696,7 @@ let hasSideEffects = false in {
693696 defm CVT_to_tf32_rz_relu_satf : CVT_TO_TF32<"rz.relu.satfinite", [hasPTX<86>, hasSM<100>]>;
694697}
695698
696- def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
697- return N->hasOneUse();
698- }]>;
699-
699+ def fpround_oneuse : OneUse1<fpround>;
700700def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse f32:$lo)),
701701 (bf16 (fpround_oneuse f32:$hi)))),
702702 (CVT_bf16x2_f32 $hi, $lo, CvtRN)>,
@@ -786,18 +786,14 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
786786// Test Instructions
787787//-----------------------------------
788788
789+ def fabs_oneuse : OneUse1<fabs>;
790+
789791def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
790792 "testp.infinite.f32 \t$p, $a;",
791- []>;
792- def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
793- "testp.infinite.f32 \t$p, $a;",
794- []>;
793+ [(set i1:$p, (seteq (fabs_oneuse f32:$a), fpimm_pos_inf<f32>))]>;
795794def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
796795 "testp.infinite.f64 \t$p, $a;",
797- []>;
798- def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
799- "testp.infinite.f64 \t$p, $a;",
800- []>;
796+ [(set i1:$p, (seteq (fabs_oneuse f64:$a), fpimm_pos_inf<f64>))]>;
801797
802798//-----------------------------------
803799// Integer Arithmetic
@@ -1362,99 +1358,19 @@ defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
13621358defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
13631359
13641360// sin/cos
1361+
1362+ class UnaryOpAllowsApproxFn<SDPatternOperator operator>
1363+ : PatFrag<(ops node:$A),
1364+ (operator node:$A), [{
1365+ return allowUnsafeFPMath() || N->getFlags().hasApproximateFuncs();
1366+ }]>;
1367+
13651368def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
13661369 "sin.approx.f32 \t$dst, $src;",
1367- [(set f32:$dst, (fsin f32:$src))]>,
1368- Requires<[allowUnsafeFPMath]>;
1370+ [(set f32:$dst, (UnaryOpAllowsApproxFn<fsin> f32:$src))]>;
13691371def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
13701372 "cos.approx.f32 \t$dst, $src;",
1371- [(set f32:$dst, (fcos f32:$src))]>,
1372- Requires<[allowUnsafeFPMath]>;
1373-
1374- // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
1375- // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches the
1376- // semantics of LLVM's frem.
1377-
1378- // frem - f32 FTZ
1379- def : Pat<(frem f32:$x, f32:$y),
1380- (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
1381- (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
1382- $y))>,
1383- Requires<[doF32FTZ, allowUnsafeFPMath]>;
1384- def : Pat<(frem f32:$x, fpimm:$y),
1385- (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
1386- (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
1387- fpimm:$y))>,
1388- Requires<[doF32FTZ, allowUnsafeFPMath]>;
1389-
1390- def : Pat<(frem f32:$x, f32:$y),
1391- (SELP_f32rr $x,
1392- (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
1393- (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
1394- $y)),
1395- (TESTINF_f32r $y))>,
1396- Requires<[doF32FTZ, noUnsafeFPMath]>;
1397- def : Pat<(frem f32:$x, fpimm:$y),
1398- (SELP_f32rr $x,
1399- (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
1400- (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
1401- fpimm:$y)),
1402- (TESTINF_f32i fpimm:$y))>,
1403- Requires<[doF32FTZ, noUnsafeFPMath]>;
1404-
1405- // frem - f32
1406- def : Pat<(frem f32:$x, f32:$y),
1407- (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
1408- (FDIV32rr_prec $x, $y), CvtRZI),
1409- $y))>,
1410- Requires<[allowUnsafeFPMath]>;
1411- def : Pat<(frem f32:$x, fpimm:$y),
1412- (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
1413- (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
1414- fpimm:$y))>,
1415- Requires<[allowUnsafeFPMath]>;
1416-
1417- def : Pat<(frem f32:$x, f32:$y),
1418- (SELP_f32rr $x,
1419- (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
1420- (FDIV32rr_prec $x, $y), CvtRZI),
1421- $y)),
1422- (TESTINF_f32r Float32Regs:$y))>,
1423- Requires<[noUnsafeFPMath]>;
1424- def : Pat<(frem f32:$x, fpimm:$y),
1425- (SELP_f32rr $x,
1426- (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
1427- (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
1428- fpimm:$y)),
1429- (TESTINF_f32i fpimm:$y))>,
1430- Requires<[noUnsafeFPMath]>;
1431-
1432- // frem - f64
1433- def : Pat<(frem f64:$x, f64:$y),
1434- (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
1435- (FDIV64rr $x, $y), CvtRZI),
1436- $y))>,
1437- Requires<[allowUnsafeFPMath]>;
1438- def : Pat<(frem f64:$x, fpimm:$y),
1439- (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
1440- (FDIV64ri $x, fpimm:$y), CvtRZI),
1441- fpimm:$y))>,
1442- Requires<[allowUnsafeFPMath]>;
1443-
1444- def : Pat<(frem f64:$x, f64:$y),
1445- (SELP_f64rr $x,
1446- (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
1447- (FDIV64rr $x, $y), CvtRZI),
1448- $y)),
1449- (TESTINF_f64r Float64Regs:$y))>,
1450- Requires<[noUnsafeFPMath]>;
1451- def : Pat<(frem f64:$x, fpimm:$y),
1452- (SELP_f64rr $x,
1453- (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
1454- (FDIV64ri $x, fpimm:$y), CvtRZI),
1455- fpimm:$y)),
1456- (TESTINF_f64r $y))>,
1457- Requires<[noUnsafeFPMath]>;
1373+ [(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
14581374
14591375//-----------------------------------
14601376// Bitwise operations
0 commit comments