@@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
150
150
151
151
// Selection predicate whose condition is the bare expression `doMulWide`
// (no call parentheses, unlike the doRsqrtOpt() predicate defined just before).
def doMulWide : Predicate<"doMulWide">;
152
152
153
- def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
154
- def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">;
155
-
156
153
// Selection predicates keyed on the value returned by getDivF32Level().
// NOTE(review): going by the def names, level 0 presumably selects the
// approximate f32 division and level 1 the full-range one -- confirm against
// the getDivF32Level() implementation.
def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
157
154
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
158
155
@@ -211,6 +208,12 @@ class ValueToRegClass<ValueType T> {
211
208
// Some Common Instruction Class Templates
212
209
//===----------------------------------------------------------------------===//
213
210
211
// PatFrag wrapper for a single-operand pattern operator: it matches
// (operator $A) only when the matched node N satisfies N->hasOneUse().
+ class OneUse1<SDPatternOperator operator>
212
+ : PatFrag<(ops node:$A), (operator node:$A), [{ return N->hasOneUse(); }]>;
213
+
214
// Floating-point immediate leaf of value type `vt` that matches only when
// Imm.isPosInfinity() is true, i.e. the immediate is positive infinity.
+ class fpimm_pos_inf<ValueType vt>
215
+ : FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
216
+
214
217
// Utility class to wrap up information about a register and DAG type for more
215
218
// convenient iteration and parameterization
216
219
class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
@@ -442,7 +445,7 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
442
445
// PatFrag wrapper for a two-operand pattern operator: it matches
// (operator $A, $B) only when allowFMA() returns true or the matched node's
// flags report hasAllowContract().
class BinOpAllowsFMA<SDPatternOperator operator>
443
446
: PatFrag<(ops node:$A, node:$B),
444
447
(operator node:$A, node:$B), [{
445
- return allowFMA() || N->getFlags().hasAllowContract();;
448
+ return allowFMA() || N->getFlags().hasAllowContract();
446
449
}]>;
447
450
448
451
multiclass F3_fma_component<string op_str, SDNode op_node> {
@@ -693,10 +696,7 @@ let hasSideEffects = false in {
693
696
defm CVT_to_tf32_rz_relu_satf : CVT_TO_TF32<"rz.relu.satfinite", [hasPTX<86>, hasSM<100>]>;
694
697
}
695
698
696
- def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
697
- return N->hasOneUse();
698
- }]>;
699
-
699
// fpround restricted to nodes with N->hasOneUse(), built from the OneUse1
// helper class instead of a hand-written PatFrag.
+ def fpround_oneuse : OneUse1<fpround>;
700
700
def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse f32:$lo)),
701
701
(bf16 (fpround_oneuse f32:$hi)))),
702
702
(CVT_bf16x2_f32 $hi, $lo, CvtRN)>,
@@ -786,18 +786,14 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
786
786
// Test Instructions
787
787
//-----------------------------------
788
788
789
// fabs restricted to nodes with N->hasOneUse(), built from the OneUse1 helper
// class.
+ def fabs_oneuse : OneUse1<fabs>;
790
+
789
791
// testp.infinite.f32 on a Float32Regs operand, writing an i1 result to $p.
// The selection pattern matches (seteq (fabs_oneuse $a), +inf): an equality
// compare of a single-use fabs of $a against the positive-infinity immediate.
// The immediate-operand variant TESTINF_f32i is deleted by this change
// (the '-' lines below).
def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
790
792
"testp.infinite.f32 \t$p, $a;",
791
- []>;
792
- def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
793
- "testp.infinite.f32 \t$p, $a;",
794
- []>;
793
+ [(set i1:$p, (seteq (fabs_oneuse f32:$a), fpimm_pos_inf<f32>))]>;
795
794
// testp.infinite.f64 on a Float64Regs operand, writing an i1 result to $p.
// The selection pattern matches (seteq (fabs_oneuse $a), +inf): an equality
// compare of a single-use fabs of $a against the positive-infinity immediate.
// The immediate-operand variant TESTINF_f64i is deleted by this change
// (the '-' lines below).
def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
796
795
"testp.infinite.f64 \t$p, $a;",
797
- []>;
798
- def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
799
- "testp.infinite.f64 \t$p, $a;",
800
- []>;
796
+ [(set i1:$p, (seteq (fabs_oneuse f64:$a), fpimm_pos_inf<f64>))]>;
801
797
802
798
//-----------------------------------
803
799
// Integer Arithmetic
@@ -1362,99 +1358,19 @@ defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
1362
1358
// f64 instantiation of the FMA multiclass: mnemonic "fma.rn.f64", register
// class Float64Regs, immediate operand f64imm.
// NOTE(review): the meaning of the trailing `True` argument is defined by the
// FMA multiclass, which is not visible in this hunk -- confirm there.
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
1363
1359
1364
1360
// sin/cos
1361
+
1362
// PatFrag wrapper for a one-operand pattern operator: it matches
// (operator $A) only when allowUnsafeFPMath() returns true or the matched
// node's flags report hasApproximateFuncs().
+ class UnaryOpAllowsApproxFn<SDPatternOperator operator>
1363
+ : PatFrag<(ops node:$A),
1364
+ (operator node:$A), [{
1365
+ return allowUnsafeFPMath() || N->getFlags().hasApproximateFuncs();
1366
+ }]>;
1367
+
1365
1368
// sin.approx.f32 on Float32Regs. Selected for f32 fsin nodes that pass the
// UnaryOpAllowsApproxFn predicate; this per-node check replaces the former
// Requires<[allowUnsafeFPMath]> gate (the '-' lines below).
def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
1366
1369
"sin.approx.f32 \t$dst, $src;",
1367
- [(set f32:$dst, (fsin f32:$src))]>,
1368
- Requires<[allowUnsafeFPMath]>;
1370
+ [(set f32:$dst, (UnaryOpAllowsApproxFn<fsin> f32:$src))]>;
1369
1371
// cos.approx.f32 on Float32Regs. Selected for f32 fcos nodes that pass the
// UnaryOpAllowsApproxFn predicate; this per-node check replaces the former
// Requires<[allowUnsafeFPMath]> gate. All '-' lines between the asm string and
// the new pattern are deletions: the old fcos pattern and every frem
// selection pattern (f32 FTZ, f32, f64) that used to follow it.
def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
1370
1372
"cos.approx.f32 \t$dst, $src;",
1371
- [(set f32:$dst, (fcos f32:$src))]>,
1372
- Requires<[allowUnsafeFPMath]>;
1373
-
1374
- // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
1375
- // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches the
1376
- // semantics of LLVM's frem.
1377
-
1378
- // frem - f32 FTZ
1379
- def : Pat<(frem f32:$x, f32:$y),
1380
- (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
1381
- (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
1382
- $y))>,
1383
- Requires<[doF32FTZ, allowUnsafeFPMath]>;
1384
- def : Pat<(frem f32:$x, fpimm:$y),
1385
- (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
1386
- (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
1387
- fpimm:$y))>,
1388
- Requires<[doF32FTZ, allowUnsafeFPMath]>;
1389
-
1390
- def : Pat<(frem f32:$x, f32:$y),
1391
- (SELP_f32rr $x,
1392
- (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
1393
- (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
1394
- $y)),
1395
- (TESTINF_f32r $y))>,
1396
- Requires<[doF32FTZ, noUnsafeFPMath]>;
1397
- def : Pat<(frem f32:$x, fpimm:$y),
1398
- (SELP_f32rr $x,
1399
- (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
1400
- (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
1401
- fpimm:$y)),
1402
- (TESTINF_f32i fpimm:$y))>,
1403
- Requires<[doF32FTZ, noUnsafeFPMath]>;
1404
-
1405
- // frem - f32
1406
- def : Pat<(frem f32:$x, f32:$y),
1407
- (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
1408
- (FDIV32rr_prec $x, $y), CvtRZI),
1409
- $y))>,
1410
- Requires<[allowUnsafeFPMath]>;
1411
- def : Pat<(frem f32:$x, fpimm:$y),
1412
- (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
1413
- (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
1414
- fpimm:$y))>,
1415
- Requires<[allowUnsafeFPMath]>;
1416
-
1417
- def : Pat<(frem f32:$x, f32:$y),
1418
- (SELP_f32rr $x,
1419
- (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
1420
- (FDIV32rr_prec $x, $y), CvtRZI),
1421
- $y)),
1422
- (TESTINF_f32r Float32Regs:$y))>,
1423
- Requires<[noUnsafeFPMath]>;
1424
- def : Pat<(frem f32:$x, fpimm:$y),
1425
- (SELP_f32rr $x,
1426
- (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
1427
- (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
1428
- fpimm:$y)),
1429
- (TESTINF_f32i fpimm:$y))>,
1430
- Requires<[noUnsafeFPMath]>;
1431
-
1432
- // frem - f64
1433
- def : Pat<(frem f64:$x, f64:$y),
1434
- (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
1435
- (FDIV64rr $x, $y), CvtRZI),
1436
- $y))>,
1437
- Requires<[allowUnsafeFPMath]>;
1438
- def : Pat<(frem f64:$x, fpimm:$y),
1439
- (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
1440
- (FDIV64ri $x, fpimm:$y), CvtRZI),
1441
- fpimm:$y))>,
1442
- Requires<[allowUnsafeFPMath]>;
1443
-
1444
- def : Pat<(frem f64:$x, f64:$y),
1445
- (SELP_f64rr $x,
1446
- (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
1447
- (FDIV64rr $x, $y), CvtRZI),
1448
- $y)),
1449
- (TESTINF_f64r Float64Regs:$y))>,
1450
- Requires<[noUnsafeFPMath]>;
1451
- def : Pat<(frem f64:$x, fpimm:$y),
1452
- (SELP_f64rr $x,
1453
- (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
1454
- (FDIV64ri $x, fpimm:$y), CvtRZI),
1455
- fpimm:$y)),
1456
- (TESTINF_f64r $y))>,
1457
- Requires<[noUnsafeFPMath]>;
1373
+ [(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
1458
1374
1459
1375
//-----------------------------------
1460
1376
// Bitwise operations
0 commit comments