Skip to content

Commit 026166f

Browse files
committed
Add constant-folding for unary NVVM intrinsics
Add support for constant-folding numerous NVVM unary arithmetic intrinsics (including the f, d, and ftz_f variants):
- nvvm.ceil.*
- nvvm.cos.approx.*
- nvvm.ex2.approx.*
- nvvm.fabs.*
- nvvm.floor.*
- nvvm.lg2.approx.*
- nvvm.rcp.*
- nvvm.round.*
- nvvm.rsqrt.approx.*
- nvvm.saturate.*
- nvvm.sin.approx.*
- nvvm.sqrt.f
- nvvm.sqrt.rn.*
- nvvm.sqrt.approx.*
1 parent 7521ce9 commit 026166f

File tree

3 files changed

+1329
-4
lines changed

3 files changed

+1329
-4
lines changed

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,127 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
334334
return false;
335335
}
336336

337+
/// Classifies a unary NVVM math intrinsic by whether it is one of the
/// flush-to-zero ("ftz") variants.
///
/// \param IntrinsicID one of the ceil / cos.approx / ex2.approx / fabs /
///        floor / lg2.approx / round / rsqrt.approx / saturate / sin.approx /
///        sqrt intrinsics enumerated below.
/// \returns true for the ftz variants, false for the plain ones.
///
/// Any ID that is not listed below reaches llvm_unreachable.
inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain (non-ftz) variants, grouped per operation.
  case Intrinsic::nvvm_ceil_d:
  case Intrinsic::nvvm_ceil_f:
  case Intrinsic::nvvm_cos_approx_f:
  case Intrinsic::nvvm_ex2_approx_d:
  case Intrinsic::nvvm_ex2_approx_f:
  case Intrinsic::nvvm_fabs:
  case Intrinsic::nvvm_floor_d:
  case Intrinsic::nvvm_floor_f:
  case Intrinsic::nvvm_lg2_approx_d:
  case Intrinsic::nvvm_lg2_approx_f:
  case Intrinsic::nvvm_round_d:
  case Intrinsic::nvvm_round_f:
  case Intrinsic::nvvm_rsqrt_approx_d:
  case Intrinsic::nvvm_rsqrt_approx_f:
  case Intrinsic::nvvm_saturate_d:
  case Intrinsic::nvvm_saturate_f:
  case Intrinsic::nvvm_sin_approx_f:
  case Intrinsic::nvvm_sqrt_f:
  case Intrinsic::nvvm_sqrt_approx_f:
  case Intrinsic::nvvm_sqrt_rn_d:
  case Intrinsic::nvvm_sqrt_rn_f:
    return false;

  // ftz variants, in the same per-operation order.
  case Intrinsic::nvvm_ceil_ftz_f:
  case Intrinsic::nvvm_cos_approx_ftz_f:
  case Intrinsic::nvvm_ex2_approx_ftz_f:
  case Intrinsic::nvvm_fabs_ftz:
  case Intrinsic::nvvm_floor_ftz_f:
  case Intrinsic::nvvm_lg2_approx_ftz_f:
  case Intrinsic::nvvm_round_ftz_f:
  case Intrinsic::nvvm_rsqrt_approx_ftz_d:
  case Intrinsic::nvvm_rsqrt_approx_ftz_f:
  case Intrinsic::nvvm_saturate_ftz_f:
  case Intrinsic::nvvm_sin_approx_ftz_f:
  case Intrinsic::nvvm_sqrt_approx_ftz_f:
  case Intrinsic::nvvm_sqrt_rn_ftz_f:
    return true;
  }

  // Only the intrinsics enumerated above are valid inputs.
  llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
  return false;
}
379+
380+
/// Classifies an NVVM reciprocal (rcp) intrinsic by whether it is one of the
/// flush-to-zero ("ftz") variants.
///
/// \param IntrinsicID one of the nvvm.rcp.* intrinsics enumerated below.
/// \returns true for the ftz variants (including both approx forms), false
///          for the plain rm / rn / rp / rz forms.
///
/// Any ID that is not listed below reaches llvm_unreachable.
inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain variants, one f/d pair per rounding mode.
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_d:
    return false;

  // ftz variants; the approx forms listed here exist only with ftz.
  case Intrinsic::nvvm_rcp_approx_ftz_d:
  case Intrinsic::nvvm_rcp_approx_ftz_f:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return true;
  }

  // Only the rcp intrinsics enumerated above are valid inputs.
  llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
  return false;
}
402+
403+
/// Maps an NVVM reciprocal (rcp) intrinsic to the APFloat rounding mode that
/// should be used when constant-folding it.
///
/// The rounding suffix in the intrinsic name selects the mode:
///   - rm -> APFloat::rmTowardNegative
///   - rn -> APFloat::rmNearestTiesToEven
///   - rp -> APFloat::rmTowardPositive
///   - rz -> APFloat::rmTowardZero
/// The approx variants carry no rounding suffix and are folded with
/// round-to-nearest-even.
///
/// \param IntrinsicID one of the nvvm.rcp.* intrinsics enumerated below.
/// \returns the rounding mode for that intrinsic.
///
/// Any ID that is not listed below reaches llvm_unreachable.
inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
    return APFloat::rmTowardNegative;

  case Intrinsic::nvvm_rcp_approx_ftz_f:
  case Intrinsic::nvvm_rcp_approx_ftz_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;

  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
    // rp rounds toward +infinity; it must not share the rn mode, otherwise
    // rcp.rp would be folded to a differently rounded constant.
    return APFloat::rmTowardPositive;

  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_d:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return APFloat::rmTowardZero;
  }
  llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
  return APFloat::roundingMode::Invalid;
}
430+
431+
/// Classifies an NVVM reciprocal (rcp) intrinsic by whether it is one of the
/// approximate variants.
///
/// \param IntrinsicID one of the nvvm.rcp.* intrinsics enumerated below.
/// \returns true for the rcp.approx forms, false for the forms that name an
///          explicit rounding mode (rm / rn / rp / rz).
///
/// Any ID that is not listed below reaches llvm_unreachable.
inline bool RCPIsApprox(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Explicitly rounded variants.
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
  case Intrinsic::nvvm_rcp_rz_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return false;

  // Approximate variants.
  case Intrinsic::nvvm_rcp_approx_ftz_d:
  case Intrinsic::nvvm_rcp_approx_ftz_f:
    return true;
  }

  // Only the rcp intrinsics enumerated above are valid inputs.
  llvm_unreachable("Checking approx flag for invalid rcp intrinsic");
  return false;
}
457+
337458
} // namespace nvvm
338459
} // namespace llvm
339460
#endif // LLVM_IR_NVVMINTRINSICUTILS_H

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 205 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,67 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
17761776
case Intrinsic::nvvm_d2ull_rp:
17771777
case Intrinsic::nvvm_d2ull_rz:
17781778

1779+
// NVVM math intrinsics:
1780+
case Intrinsic::nvvm_ceil_d:
1781+
case Intrinsic::nvvm_ceil_f:
1782+
case Intrinsic::nvvm_ceil_ftz_f:
1783+
1784+
case Intrinsic::nvvm_cos_approx_f:
1785+
case Intrinsic::nvvm_cos_approx_ftz_f:
1786+
1787+
case Intrinsic::nvvm_ex2_approx_d:
1788+
case Intrinsic::nvvm_ex2_approx_f:
1789+
case Intrinsic::nvvm_ex2_approx_ftz_f:
1790+
1791+
case Intrinsic::nvvm_fabs:
1792+
case Intrinsic::nvvm_fabs_ftz:
1793+
1794+
case Intrinsic::nvvm_floor_d:
1795+
case Intrinsic::nvvm_floor_f:
1796+
case Intrinsic::nvvm_floor_ftz_f:
1797+
1798+
case Intrinsic::nvvm_lg2_approx_d:
1799+
case Intrinsic::nvvm_lg2_approx_f:
1800+
case Intrinsic::nvvm_lg2_approx_ftz_f:
1801+
1802+
case Intrinsic::nvvm_rcp_rm_d:
1803+
case Intrinsic::nvvm_rcp_rm_f:
1804+
case Intrinsic::nvvm_rcp_rm_ftz_f:
1805+
case Intrinsic::nvvm_rcp_rn_d:
1806+
case Intrinsic::nvvm_rcp_rn_f:
1807+
case Intrinsic::nvvm_rcp_rn_ftz_f:
1808+
case Intrinsic::nvvm_rcp_rp_d:
1809+
case Intrinsic::nvvm_rcp_rp_f:
1810+
case Intrinsic::nvvm_rcp_rp_ftz_f:
1811+
case Intrinsic::nvvm_rcp_rz_d:
1812+
case Intrinsic::nvvm_rcp_rz_f:
1813+
case Intrinsic::nvvm_rcp_rz_ftz_f:
1814+
case Intrinsic::nvvm_rcp_approx_ftz_d:
1815+
case Intrinsic::nvvm_rcp_approx_ftz_f:
1816+
1817+
case Intrinsic::nvvm_round_d:
1818+
case Intrinsic::nvvm_round_f:
1819+
case Intrinsic::nvvm_round_ftz_f:
1820+
1821+
case Intrinsic::nvvm_rsqrt_approx_d:
1822+
case Intrinsic::nvvm_rsqrt_approx_f:
1823+
case Intrinsic::nvvm_rsqrt_approx_ftz_d:
1824+
case Intrinsic::nvvm_rsqrt_approx_ftz_f:
1825+
1826+
case Intrinsic::nvvm_saturate_d:
1827+
case Intrinsic::nvvm_saturate_f:
1828+
case Intrinsic::nvvm_saturate_ftz_f:
1829+
1830+
case Intrinsic::nvvm_sin_approx_f:
1831+
case Intrinsic::nvvm_sin_approx_ftz_f:
1832+
1833+
case Intrinsic::nvvm_sqrt_f:
1834+
case Intrinsic::nvvm_sqrt_rn_d:
1835+
case Intrinsic::nvvm_sqrt_rn_f:
1836+
case Intrinsic::nvvm_sqrt_rn_ftz_f:
1837+
case Intrinsic::nvvm_sqrt_approx_f:
1838+
case Intrinsic::nvvm_sqrt_approx_ftz_f:
1839+
17791840
// Sign operations are actually bitwise operations, they do not raise
17801841
// exceptions even for SNANs.
17811842
case Intrinsic::fabs:
@@ -1791,6 +1852,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
17911852
case Intrinsic::nearbyint:
17921853
case Intrinsic::rint:
17931854
case Intrinsic::canonicalize:
1855+
17941856
// Constrained intrinsics can be folded if FP environment is known
17951857
// to compiler.
17961858
case Intrinsic::experimental_constrained_fma:
@@ -1944,16 +2006,23 @@ static const APFloat FTZPreserveSign(const APFloat &V) {
19442006
return V;
19452007
}
19462008

1947-
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
1948-
Type *Ty) {
2009+
// Folds a call to a host libm-style function: evaluates NativeFP on V
// (converted to double) and turns the result into a constant of type Ty via
// GetConstantFoldFPValue.
//
// NativeFP: host function computing the operation on a double.
// V: the constant operand.
// Ty: type of the constant to produce.
// ShouldFTZPreservingSign: when true, both the operand and the produced
//   constant are passed through FTZPreserveSign (presumably flushing
//   denormals to a zero of the same sign -- its body is not visible in this
//   hunk; confirm).
//
// Returns nullptr when llvm_fenv_testexcept() reports a floating-point
// exception after the host call; otherwise returns the folded constant.
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
2010+
bool ShouldFTZPreservingSign = false) {
19492011
// Start from a clean exception state so the check below only sees
// exceptions raised by NativeFP.
llvm_fenv_clearexcept();
1950-
double Result = NativeFP(V.convertToDouble());
2012+
// Flush the operand first when the ftz behaviour was requested.
auto Input = ShouldFTZPreservingSign ? FTZPreserveSign(V) : V;
2013+
double Result = NativeFP(Input.convertToDouble());
19512014
if (llvm_fenv_testexcept()) {
19522015
llvm_fenv_clearexcept();
19532016
return nullptr;
19542017
}
19552018

1956-
return GetConstantFoldFPValue(Result, Ty);
2019+
Constant *Output = GetConstantFoldFPValue(Result, Ty);
2020+
if (ShouldFTZPreservingSign) {
2021+
// NOTE(review): assumes GetConstantFoldFPValue always yields a non-null
// ConstantFP here -- the static_cast performs no check; confirm.
const auto *CFP = static_cast<ConstantFP *>(Output);
2022+
// Flush the folded result as well, not just the operand.
return ConstantFP::get(Ty->getContext(),
2023+
FTZPreserveSign(CFP->getValueAPF()));
2024+
}
2025+
return Output;
19572026
}
19582027

19592028
#if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
@@ -2524,6 +2593,138 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
25242593
return ConstantFoldFP(cosh, APF, Ty);
25252594
case Intrinsic::sqrt:
25262595
return ConstantFoldFP(sqrt, APF, Ty);
2596+
2597+
// NVVM Intrinsics:
2598+
case Intrinsic::nvvm_ceil_ftz_f:
2599+
case Intrinsic::nvvm_ceil_f:
2600+
case Intrinsic::nvvm_ceil_d:
2601+
return ConstantFoldFP(ceil, APF, Ty,
2602+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2603+
2604+
case Intrinsic::nvvm_cos_approx_ftz_f:
2605+
case Intrinsic::nvvm_cos_approx_f:
2606+
return ConstantFoldFP(cos, APF, Ty,
2607+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2608+
2609+
case Intrinsic::nvvm_ex2_approx_ftz_f:
2610+
case Intrinsic::nvvm_ex2_approx_d:
2611+
case Intrinsic::nvvm_ex2_approx_f:
2612+
return ConstantFoldFP(exp2, APF, Ty,
2613+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2614+
2615+
case Intrinsic::nvvm_fabs_ftz:
2616+
case Intrinsic::nvvm_fabs:
2617+
return ConstantFoldFP(fabs, APF, Ty,
2618+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2619+
2620+
case Intrinsic::nvvm_floor_ftz_f:
2621+
case Intrinsic::nvvm_floor_f:
2622+
case Intrinsic::nvvm_floor_d:
2623+
return ConstantFoldFP(floor, APF, Ty,
2624+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2625+
2626+
case Intrinsic::nvvm_lg2_approx_ftz_f:
2627+
case Intrinsic::nvvm_lg2_approx_d:
2628+
case Intrinsic::nvvm_lg2_approx_f: {
2629+
if (APF.isNegative() || APF.isZero())
2630+
return nullptr;
2631+
return ConstantFoldFP(log2, APF, Ty,
2632+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2633+
}
2634+
2635+
case Intrinsic::nvvm_rcp_rm_ftz_f:
2636+
case Intrinsic::nvvm_rcp_rn_ftz_f:
2637+
case Intrinsic::nvvm_rcp_rp_ftz_f:
2638+
case Intrinsic::nvvm_rcp_rz_ftz_f:
2639+
case Intrinsic::nvvm_rcp_approx_ftz_f:
2640+
case Intrinsic::nvvm_rcp_approx_ftz_d:
2641+
case Intrinsic::nvvm_rcp_rm_d:
2642+
case Intrinsic::nvvm_rcp_rm_f:
2643+
case Intrinsic::nvvm_rcp_rn_d:
2644+
case Intrinsic::nvvm_rcp_rn_f:
2645+
case Intrinsic::nvvm_rcp_rp_d:
2646+
case Intrinsic::nvvm_rcp_rp_f:
2647+
case Intrinsic::nvvm_rcp_rz_d:
2648+
case Intrinsic::nvvm_rcp_rz_f: {
2649+
APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
2650+
bool IsApprox = nvvm::RCPIsApprox(IntrinsicID);
2651+
bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);
2652+
2653+
auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
2654+
if (IsApprox && Denominator.isZero()) {
2655+
// According to the PTX spec, approximate rcp should return infinity
2656+
// with the same sign as the denominator when dividing by 0.
2657+
APFloat Inf = APFloat::getInf(APF.getSemantics(), APF.isNegative());
2658+
return ConstantFP::get(Ty->getContext(), Inf);
2659+
}
2660+
APFloat Res = APFloat::getOne(APF.getSemantics());
2661+
APFloat::opStatus Status = Res.divide(Denominator, RoundMode);
2662+
2663+
if (Status == APFloat::opOK || Status == APFloat::opInexact) {
2664+
if (IsFTZ)
2665+
Res = FTZPreserveSign(Res);
2666+
return ConstantFP::get(Ty->getContext(), Res);
2667+
}
2668+
return nullptr;
2669+
}
2670+
2671+
case Intrinsic::nvvm_round_ftz_f:
2672+
case Intrinsic::nvvm_round_f:
2673+
case Intrinsic::nvvm_round_d:
2674+
return ConstantFoldFP(round, APF, Ty,
2675+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2676+
2677+
case Intrinsic::nvvm_rsqrt_approx_ftz_d:
2678+
case Intrinsic::nvvm_rsqrt_approx_ftz_f:
2679+
case Intrinsic::nvvm_rsqrt_approx_d:
2680+
case Intrinsic::nvvm_rsqrt_approx_f: {
2681+
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
2682+
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
2683+
APFloat SqrtV(sqrt(V.convertToDouble()));
2684+
2685+
bool lost;
2686+
SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven, &lost);
2687+
2688+
APFloat Res = APFloat::getOne(APF.getSemantics());
2689+
Res.divide(SqrtV, APFloat::rmNearestTiesToEven);
2690+
2691+
// We do not need to flush the output for ftz because it is impossible
2692+
// for 1/sqrt(x) to be a denormal value. If x is the largest fp value,
2693+
// sqrt(x) will be a number with the exponent approximately halved and
2694+
// the reciprocal of that number can't be small enough to be denormal.
2695+
return ConstantFP::get(Ty->getContext(), Res);
2696+
}
2697+
2698+
case Intrinsic::nvvm_saturate_ftz_f:
2699+
case Intrinsic::nvvm_saturate_d:
2700+
case Intrinsic::nvvm_saturate_f: {
2701+
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
2702+
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
2703+
if (V.isNegative() || V.isZero() || V.isNaN())
2704+
return ConstantFP::getZero(Ty);
2705+
APFloat One = APFloat::getOne(APF.getSemantics());
2706+
if (V > One)
2707+
return ConstantFP::get(Ty->getContext(), One);
2708+
return ConstantFP::get(Ty->getContext(), APF);
2709+
}
2710+
2711+
case Intrinsic::nvvm_sin_approx_ftz_f:
2712+
case Intrinsic::nvvm_sin_approx_f:
2713+
return ConstantFoldFP(sin, APF, Ty,
2714+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2715+
2716+
case Intrinsic::nvvm_sqrt_rn_ftz_f:
2717+
case Intrinsic::nvvm_sqrt_approx_ftz_f:
2718+
case Intrinsic::nvvm_sqrt_f:
2719+
case Intrinsic::nvvm_sqrt_rn_d:
2720+
case Intrinsic::nvvm_sqrt_rn_f:
2721+
case Intrinsic::nvvm_sqrt_approx_f:
2722+
if (APF.isNegative())
2723+
return nullptr;
2724+
return ConstantFoldFP(sqrt, APF, Ty,
2725+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
2726+
2727+
// AMDGCN Intrinsics:
25272728
case Intrinsic::amdgcn_cos:
25282729
case Intrinsic::amdgcn_sin: {
25292730
double V = getValueAsDouble(Op);

0 commit comments

Comments
 (0)