Skip to content
121 changes: 121 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,127 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
return false;
}

inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
case Intrinsic::nvvm_ceil_ftz_f:
case Intrinsic::nvvm_cos_approx_ftz_f:
case Intrinsic::nvvm_ex2_approx_ftz_f:
case Intrinsic::nvvm_fabs_ftz:
case Intrinsic::nvvm_floor_ftz_f:
case Intrinsic::nvvm_lg2_approx_ftz_f:
case Intrinsic::nvvm_round_ftz_f:
case Intrinsic::nvvm_rsqrt_approx_ftz_d:
case Intrinsic::nvvm_rsqrt_approx_ftz_f:
case Intrinsic::nvvm_saturate_ftz_f:
case Intrinsic::nvvm_sin_approx_ftz_f:
case Intrinsic::nvvm_sqrt_rn_ftz_f:
case Intrinsic::nvvm_sqrt_approx_ftz_f:
return true;
case Intrinsic::nvvm_ceil_f:
case Intrinsic::nvvm_ceil_d:
case Intrinsic::nvvm_cos_approx_f:
case Intrinsic::nvvm_ex2_approx_d:
case Intrinsic::nvvm_ex2_approx_f:
case Intrinsic::nvvm_fabs:
case Intrinsic::nvvm_floor_f:
case Intrinsic::nvvm_floor_d:
case Intrinsic::nvvm_lg2_approx_d:
case Intrinsic::nvvm_lg2_approx_f:
case Intrinsic::nvvm_round_f:
case Intrinsic::nvvm_round_d:
case Intrinsic::nvvm_rsqrt_approx_d:
case Intrinsic::nvvm_rsqrt_approx_f:
case Intrinsic::nvvm_saturate_d:
case Intrinsic::nvvm_saturate_f:
case Intrinsic::nvvm_sin_approx_f:
case Intrinsic::nvvm_sqrt_f:
case Intrinsic::nvvm_sqrt_rn_d:
case Intrinsic::nvvm_sqrt_rn_f:
case Intrinsic::nvvm_sqrt_approx_f:
return false;
}
llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
return false;
}

inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
case Intrinsic::nvvm_rcp_rm_ftz_f:
case Intrinsic::nvvm_rcp_rn_ftz_f:
case Intrinsic::nvvm_rcp_rp_ftz_f:
case Intrinsic::nvvm_rcp_rz_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_d:
return true;
case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_f:
return false;
}
llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
return false;
}

inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_ftz_f:
return APFloat::rmTowardNegative;

case Intrinsic::nvvm_rcp_approx_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_d:
case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_ftz_f:
return APFloat::rmNearestTiesToEven;

case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_ftz_f:
return APFloat::rmNearestTiesToEven;

case Intrinsic::nvvm_rcp_rz_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_ftz_f:
return APFloat::rmTowardZero;
}
llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
return APFloat::roundingMode::Invalid;
}

inline bool RCPIsApprox(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
case Intrinsic::nvvm_rcp_approx_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_d:
return true;

case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_ftz_f:

case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_ftz_f:

case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_ftz_f:

case Intrinsic::nvvm_rcp_rz_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_ftz_f:
return false;
}
llvm_unreachable("Checking approx flag for invalid rcp intrinsic");
return false;
}

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
232 changes: 228 additions & 4 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,67 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz:

// NVVM math intrinsics:
case Intrinsic::nvvm_ceil_d:
case Intrinsic::nvvm_ceil_f:
case Intrinsic::nvvm_ceil_ftz_f:

case Intrinsic::nvvm_cos_approx_f:
case Intrinsic::nvvm_cos_approx_ftz_f:

case Intrinsic::nvvm_ex2_approx_d:
case Intrinsic::nvvm_ex2_approx_f:
case Intrinsic::nvvm_ex2_approx_ftz_f:

case Intrinsic::nvvm_fabs:
case Intrinsic::nvvm_fabs_ftz:

case Intrinsic::nvvm_floor_d:
case Intrinsic::nvvm_floor_f:
case Intrinsic::nvvm_floor_ftz_f:

case Intrinsic::nvvm_lg2_approx_d:
case Intrinsic::nvvm_lg2_approx_f:
case Intrinsic::nvvm_lg2_approx_ftz_f:

case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rm_ftz_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rn_ftz_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rp_ftz_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_f:
case Intrinsic::nvvm_rcp_rz_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_d:
case Intrinsic::nvvm_rcp_approx_ftz_f:

case Intrinsic::nvvm_round_d:
case Intrinsic::nvvm_round_f:
case Intrinsic::nvvm_round_ftz_f:

case Intrinsic::nvvm_rsqrt_approx_d:
case Intrinsic::nvvm_rsqrt_approx_f:
case Intrinsic::nvvm_rsqrt_approx_ftz_d:
case Intrinsic::nvvm_rsqrt_approx_ftz_f:

case Intrinsic::nvvm_saturate_d:
case Intrinsic::nvvm_saturate_f:
case Intrinsic::nvvm_saturate_ftz_f:

case Intrinsic::nvvm_sin_approx_f:
case Intrinsic::nvvm_sin_approx_ftz_f:

case Intrinsic::nvvm_sqrt_f:
case Intrinsic::nvvm_sqrt_rn_d:
case Intrinsic::nvvm_sqrt_rn_f:
case Intrinsic::nvvm_sqrt_rn_ftz_f:
case Intrinsic::nvvm_sqrt_approx_f:
case Intrinsic::nvvm_sqrt_approx_ftz_f:

// Sign operations are actually bitwise operations, they do not raise
// exceptions even for SNANs.
case Intrinsic::fabs:
Expand All @@ -1791,6 +1852,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nearbyint:
case Intrinsic::rint:
case Intrinsic::canonicalize:

// Constrained intrinsics can be folded if FP environment is known
// to compiler.
case Intrinsic::experimental_constrained_fma:
Expand Down Expand Up @@ -1944,16 +2006,32 @@ static const APFloat FTZPreserveSign(const APFloat &V) {
return V;
}

Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
Type *Ty) {
// Get only the upper word of the input double in 1.11.20 format
// by making the lower 32-bits of the mantissa all 0.
static const APFloat ZeroLower32Bits(const APFloat &V) {
assert(V.getSizeInBits(V.getSemantics()) == 64);
uint64_t DoubleBits = V.bitcastToAPInt().getZExtValue();
DoubleBits &= 0xffffffff00000000;
return APFloat(V.getSemantics(), APInt(64, DoubleBits, false, false));
}

Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
bool ShouldFTZPreservingSign = false) {
llvm_fenv_clearexcept();
double Result = NativeFP(V.convertToDouble());
auto Input = ShouldFTZPreservingSign ? FTZPreserveSign(V) : V;
double Result = NativeFP(Input.convertToDouble());
if (llvm_fenv_testexcept()) {
llvm_fenv_clearexcept();
return nullptr;
}

return GetConstantFoldFPValue(Result, Ty);
Constant *Output = GetConstantFoldFPValue(Result, Ty);
if (ShouldFTZPreservingSign) {
const auto *CFP = static_cast<ConstantFP *>(Output);
return ConstantFP::get(Ty->getContext(),
FTZPreserveSign(CFP->getValueAPF()));
}
return Output;
}

#if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
Expand Down Expand Up @@ -2524,6 +2602,152 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
return ConstantFoldFP(cosh, APF, Ty);
case Intrinsic::sqrt:
return ConstantFoldFP(sqrt, APF, Ty);

// NVVM Intrinsics:
case Intrinsic::nvvm_ceil_ftz_f:
case Intrinsic::nvvm_ceil_f:
case Intrinsic::nvvm_ceil_d:
return ConstantFoldFP(ceil, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_cos_approx_ftz_f:
case Intrinsic::nvvm_cos_approx_f:
return ConstantFoldFP(cos, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_ex2_approx_ftz_f:
case Intrinsic::nvvm_ex2_approx_d:
case Intrinsic::nvvm_ex2_approx_f:
return ConstantFoldFP(exp2, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_fabs_ftz:
case Intrinsic::nvvm_fabs:
return ConstantFoldFP(fabs, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_floor_ftz_f:
case Intrinsic::nvvm_floor_f:
case Intrinsic::nvvm_floor_d:
return ConstantFoldFP(floor, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_lg2_approx_ftz_f:
case Intrinsic::nvvm_lg2_approx_d:
case Intrinsic::nvvm_lg2_approx_f: {
if (APF.isNegative() || APF.isZero())
return nullptr;
return ConstantFoldFP(log2, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
}

case Intrinsic::nvvm_rcp_rm_ftz_f:
case Intrinsic::nvvm_rcp_rn_ftz_f:
case Intrinsic::nvvm_rcp_rp_ftz_f:
case Intrinsic::nvvm_rcp_rz_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_f:
case Intrinsic::nvvm_rcp_approx_ftz_d:
case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_f: {
APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
bool IsApprox = nvvm::RCPIsApprox(IntrinsicID);
bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);

auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
if (IntrinsicID == Intrinsic::nvvm_rcp_approx_ftz_d)
Denominator = ZeroLower32Bits(Denominator);
if (IsApprox && Denominator.isZero()) {
// According to the PTX spec, approximate rcp should return infinity
// with the same sign as the denominator when dividing by 0.
APFloat Inf = APFloat::getInf(APF.getSemantics(), APF.isNegative());
return ConstantFP::get(Ty->getContext(), Inf);
}
APFloat Res = APFloat::getOne(APF.getSemantics());
APFloat::opStatus Status = Res.divide(Denominator, RoundMode);

if (Status == APFloat::opOK || Status == APFloat::opInexact) {
if (IsFTZ)
Res = FTZPreserveSign(Res);
if (IntrinsicID == Intrinsic::nvvm_rcp_approx_ftz_d)
Res = ZeroLower32Bits(Res);
return ConstantFP::get(Ty->getContext(), Res);
}
return nullptr;
}

case Intrinsic::nvvm_round_ftz_f:
case Intrinsic::nvvm_round_f:
case Intrinsic::nvvm_round_d:
return ConstantFoldFP(round, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_rsqrt_approx_ftz_d:
case Intrinsic::nvvm_rsqrt_approx_ftz_f:
case Intrinsic::nvvm_rsqrt_approx_d:
case Intrinsic::nvvm_rsqrt_approx_f: {
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;

if (IntrinsicID == Intrinsic::nvvm_rsqrt_approx_ftz_d)
V = ZeroLower32Bits(V);

APFloat SqrtV(sqrt(V.convertToDouble()));

if (Ty->isFloatTy()) {
bool lost;
SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven,
&lost);
}

APFloat Res = APFloat::getOne(APF.getSemantics());
Res.divide(SqrtV, APFloat::rmNearestTiesToEven);

if (IntrinsicID == Intrinsic::nvvm_rsqrt_approx_ftz_d)
Res = ZeroLower32Bits(Res);

// We do not need to flush the output for ftz because it is impossible
// for 1/sqrt(x) to be a denormal value. If x is the largest fp value,
// sqrt(x) will be a number with the exponent approximately halved and
// the reciprocal of that number can't be small enough to be denormal.
return ConstantFP::get(Ty->getContext(), Res);
}

case Intrinsic::nvvm_saturate_ftz_f:
case Intrinsic::nvvm_saturate_d:
case Intrinsic::nvvm_saturate_f: {
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
if (V.isNegative() || V.isZero() || V.isNaN())
return ConstantFP::getZero(Ty);
APFloat One = APFloat::getOne(APF.getSemantics());
if (V > One)
return ConstantFP::get(Ty->getContext(), One);
return ConstantFP::get(Ty->getContext(), APF);
}

case Intrinsic::nvvm_sin_approx_ftz_f:
case Intrinsic::nvvm_sin_approx_f:
return ConstantFoldFP(sin, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

case Intrinsic::nvvm_sqrt_rn_ftz_f:
case Intrinsic::nvvm_sqrt_approx_ftz_f:
case Intrinsic::nvvm_sqrt_f:
case Intrinsic::nvvm_sqrt_rn_d:
case Intrinsic::nvvm_sqrt_rn_f:
case Intrinsic::nvvm_sqrt_approx_f:
if (APF.isNegative())
return nullptr;
return ConstantFoldFP(sqrt, APF, Ty,
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));

// AMDGCN Intrinsics:
case Intrinsic::amdgcn_cos:
case Intrinsic::amdgcn_sin: {
double V = getValueAsDouble(Op);
Expand Down
Loading
Loading