Skip to content
83 changes: 77 additions & 6 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
return false;
}
llvm_unreachable("Checking FTZ flag for invalid f2i/d2i intrinsic");
return false;
}

inline bool FPToIntegerIntrinsicResultIsSigned(Intrinsic::ID IntrinsicID) {
Expand Down Expand Up @@ -179,7 +178,6 @@ inline bool FPToIntegerIntrinsicResultIsSigned(Intrinsic::ID IntrinsicID) {
}
llvm_unreachable(
"Checking invalid f2i/d2i intrinsic for signed int conversion");
return false;
}

inline APFloat::roundingMode
Expand Down Expand Up @@ -250,7 +248,6 @@ GetFPToIntegerRoundingMode(Intrinsic::ID IntrinsicID) {
return APFloat::rmTowardZero;
}
llvm_unreachable("Checking rounding mode for invalid f2i/d2i intrinsic");
return APFloat::roundingMode::Invalid;
}

inline bool FMinFMaxShouldFTZ(Intrinsic::ID IntrinsicID) {
Expand Down Expand Up @@ -280,7 +277,6 @@ inline bool FMinFMaxShouldFTZ(Intrinsic::ID IntrinsicID) {
return false;
}
llvm_unreachable("Checking FTZ flag for invalid fmin/fmax intrinsic");
return false;
}

inline bool FMinFMaxPropagatesNaNs(Intrinsic::ID IntrinsicID) {
Expand Down Expand Up @@ -310,7 +306,6 @@ inline bool FMinFMaxPropagatesNaNs(Intrinsic::ID IntrinsicID) {
return false;
}
llvm_unreachable("Checking NaN flag for invalid fmin/fmax intrinsic");
return false;
}

inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
Expand Down Expand Up @@ -340,7 +335,83 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
return false;
}
llvm_unreachable("Checking XorSignAbs flag for invalid fmin/fmax intrinsic");
return false;
}

/// Reports whether an NVVM unary math intrinsic is one of the "ftz" variants.
///
/// Accepts only the ceil / fabs / floor / round / saturate / sqrt intrinsics
/// enumerated below. Passing any other intrinsic ID is a caller bug and
/// reaches llvm_unreachable.
///
/// \returns true for the *_ftz* variants, false for the plain variants.
inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants without the ftz modifier.
  case Intrinsic::nvvm_ceil_d:
  case Intrinsic::nvvm_ceil_f:
  case Intrinsic::nvvm_fabs:
  case Intrinsic::nvvm_floor_d:
  case Intrinsic::nvvm_floor_f:
  case Intrinsic::nvvm_round_d:
  case Intrinsic::nvvm_round_f:
  case Intrinsic::nvvm_saturate_d:
  case Intrinsic::nvvm_saturate_f:
  case Intrinsic::nvvm_sqrt_f:
  case Intrinsic::nvvm_sqrt_rn_d:
  case Intrinsic::nvvm_sqrt_rn_f:
    return false;

  // Variants carrying the ftz modifier.
  case Intrinsic::nvvm_ceil_ftz_f:
  case Intrinsic::nvvm_fabs_ftz:
  case Intrinsic::nvvm_floor_ftz_f:
  case Intrinsic::nvvm_round_ftz_f:
  case Intrinsic::nvvm_saturate_ftz_f:
  case Intrinsic::nvvm_sqrt_rn_ftz_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
}

/// Reports whether an NVVM reciprocal (rcp) intrinsic is an "ftz" variant.
///
/// Every rounding-mode flavour (rm / rn / rp / rz) of the f32, f64 and
/// ftz-f32 rcp intrinsics is accepted. Any other intrinsic ID is a caller bug
/// and reaches llvm_unreachable.
///
/// \returns true for the *_ftz_f variants, false for the plain f/d variants.
inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain single- and double-precision variants.
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_d:
    return false;

  // ftz variants.
  case Intrinsic::nvvm_rcp_rm_ftz_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
}

/// Maps an NVVM reciprocal (rcp) intrinsic to the APFloat rounding mode
/// encoded in its name:
///   rn -> rmNearestTiesToEven, rz -> rmTowardZero,
///   rm -> rmTowardNegative,    rp -> rmTowardPositive.
///
/// Any intrinsic ID outside the rcp family listed below is a caller bug and
/// reaches llvm_unreachable.
inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // rn
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;

  // rz
  case Intrinsic::nvvm_rcp_rz_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return APFloat::rmTowardZero;

  // rm
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
    return APFloat::rmTowardNegative;

  // rp
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
    return APFloat::rmTowardPositive;
  }
  llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
}

/// Translates an FTZ flag into a DenormalMode: the preserve-sign mode when
/// \p ShouldFTZ is set, the IEEE mode otherwise.
///
/// NOTE(review): "Denrom" looks like a misspelling of "Denorm"; the name is
/// kept as-is because callers reference it by this spelling.
inline DenormalMode GetNVVMDenromMode(bool ShouldFTZ) {
  return ShouldFTZ ? DenormalMode::getPreserveSign()
                   : DenormalMode::getIEEE();
}

} // namespace nvvm
Expand Down
168 changes: 163 additions & 5 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,44 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz:

// NVVM math intrinsics:
case Intrinsic::nvvm_ceil_d:
case Intrinsic::nvvm_ceil_f:
case Intrinsic::nvvm_ceil_ftz_f:

case Intrinsic::nvvm_fabs:
case Intrinsic::nvvm_fabs_ftz:

case Intrinsic::nvvm_floor_d:
case Intrinsic::nvvm_floor_f:
case Intrinsic::nvvm_floor_ftz_f:

case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rm_ftz_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rn_ftz_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rp_ftz_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_f:
case Intrinsic::nvvm_rcp_rz_ftz_f:

case Intrinsic::nvvm_round_d:
case Intrinsic::nvvm_round_f:
case Intrinsic::nvvm_round_ftz_f:

case Intrinsic::nvvm_saturate_d:
case Intrinsic::nvvm_saturate_f:
case Intrinsic::nvvm_saturate_ftz_f:

case Intrinsic::nvvm_sqrt_f:
case Intrinsic::nvvm_sqrt_rn_d:
case Intrinsic::nvvm_sqrt_rn_f:
case Intrinsic::nvvm_sqrt_rn_ftz_f:
return !Call->isStrictFP();

// Sign operations are actually bitwise operations, they do not raise
Expand All @@ -1816,6 +1854,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nearbyint:
case Intrinsic::rint:
case Intrinsic::canonicalize:

// Constrained intrinsics can be folded if FP environment is known
// to compiler.
case Intrinsic::experimental_constrained_fma:
Expand Down Expand Up @@ -1963,22 +2002,56 @@ inline bool llvm_fenv_testexcept() {
return false;
}

/// Flushes a denormal value to a zero with the same sign and semantics.
/// Non-denormal values are returned unchanged.
///
/// The result is returned as a non-const value so that callers assigning it
/// to an existing APFloat can move from it rather than copy.
static APFloat FTZPreserveSign(const APFloat &V) {
  if (V.isDenormal())
    return APFloat::getZero(V.getSemantics(), V.isNegative());
  return V;
}

Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
Type *Ty) {
static const APFloat FlushToPositiveZero(const APFloat &V) {
if (V.isDenormal())
return APFloat::getZero(V.getSemantics(), false);
return V;
}

/// Applies the denormal handling selected by \p DenormKind to \p V.
///
///   IEEE          -> V is returned unchanged.
///   PreserveSign  -> denormals become a zero with V's sign.
///   PositiveZero  -> denormals become +0.
///
/// \p DenormKind must not be Invalid or Dynamic; callers are expected to
/// filter those out first (asserted below).
static APFloat FlushWithDenormKind(const APFloat &V,
                                   DenormalMode::DenormalModeKind DenormKind) {
  assert(DenormKind != DenormalMode::DenormalModeKind::Invalid &&
         DenormKind != DenormalMode::DenormalModeKind::Dynamic);
  switch (DenormKind) {
  case DenormalMode::DenormalModeKind::IEEE:
    return V;
  case DenormalMode::DenormalModeKind::PreserveSign:
    return FTZPreserveSign(V);
  case DenormalMode::DenormalModeKind::PositiveZero:
    return FlushToPositiveZero(V);
  default:
    // Invalid / Dynamic: excluded by the assertion above.
    llvm_unreachable("Invalid denormal mode!");
  }
}

/// Constant-folds a unary floating-point operation by evaluating the host
/// function \p NativeFP on \p V.
///
/// \param NativeFP   Host implementation of the operation, on doubles.
/// \param V          The constant operand.
/// \param Ty         Type of the folded constant to produce.
/// \param DenormMode Denormal handling for the operand (Input) and for the
///                   result (Output). Defaults to IEEE, i.e. no flushing.
///
/// \returns the folded constant, or nullptr when folding is refused: the
/// denormal mode is invalid or dynamic, or the host evaluation raised a
/// floating-point exception.
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
                         DenormalMode DenormMode = DenormalMode::getIEEE()) {
  // FlushWithDenormKind cannot handle invalid or dynamic modes, so bail out
  // before reaching it.
  if (!DenormMode.isValid() ||
      DenormMode.Input == DenormalMode::DenormalModeKind::Dynamic ||
      DenormMode.Output == DenormalMode::DenormalModeKind::Dynamic)
    return nullptr;

  llvm_fenv_clearexcept();
  // Flush the operand per the input denormal mode before evaluating.
  const APFloat Input = FlushWithDenormKind(V, DenormMode.Input);
  const double Result = NativeFP(Input.convertToDouble());
  if (llvm_fenv_testexcept()) {
    llvm_fenv_clearexcept();
    return nullptr;
  }

  Constant *Output = GetConstantFoldFPValue(Result, Ty);
  // IEEE output needs no post-processing.
  if (DenormMode.Output == DenormalMode::DenormalModeKind::IEEE)
    return Output;

  // Otherwise flush the folded value per the output denormal mode.
  const auto *CFP = static_cast<ConstantFP *>(Output);
  const APFloat Res =
      FlushWithDenormKind(CFP->getValueAPF(), DenormMode.Output);
  return ConstantFP::get(Ty->getContext(), Res);
}

#if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
Expand Down Expand Up @@ -2548,6 +2621,91 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
return ConstantFoldFP(atan, APF, Ty);
case Intrinsic::sqrt:
return ConstantFoldFP(sqrt, APF, Ty);

// NVVM Intrinsics:
case Intrinsic::nvvm_ceil_ftz_f:
case Intrinsic::nvvm_ceil_f:
case Intrinsic::nvvm_ceil_d:
return ConstantFoldFP(
ceil, APF, Ty,
nvvm::GetNVVMDenromMode(
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));

case Intrinsic::nvvm_fabs_ftz:
case Intrinsic::nvvm_fabs:
return ConstantFoldFP(
fabs, APF, Ty,
nvvm::GetNVVMDenromMode(
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));

case Intrinsic::nvvm_floor_ftz_f:
case Intrinsic::nvvm_floor_f:
case Intrinsic::nvvm_floor_d:
return ConstantFoldFP(
floor, APF, Ty,
nvvm::GetNVVMDenromMode(
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));

case Intrinsic::nvvm_rcp_rm_ftz_f:
case Intrinsic::nvvm_rcp_rn_ftz_f:
case Intrinsic::nvvm_rcp_rp_ftz_f:
case Intrinsic::nvvm_rcp_rz_ftz_f:
case Intrinsic::nvvm_rcp_rm_d:
case Intrinsic::nvvm_rcp_rm_f:
case Intrinsic::nvvm_rcp_rn_d:
case Intrinsic::nvvm_rcp_rn_f:
case Intrinsic::nvvm_rcp_rp_d:
case Intrinsic::nvvm_rcp_rp_f:
case Intrinsic::nvvm_rcp_rz_d:
case Intrinsic::nvvm_rcp_rz_f: {
APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);

auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
APFloat Res = APFloat::getOne(APF.getSemantics());
APFloat::opStatus Status = Res.divide(Denominator, RoundMode);

if (Status == APFloat::opOK || Status == APFloat::opInexact) {
if (IsFTZ)
Res = FTZPreserveSign(Res);
return ConstantFP::get(Ty->getContext(), Res);
}
return nullptr;
}

case Intrinsic::nvvm_round_ftz_f:
case Intrinsic::nvvm_round_f:
case Intrinsic::nvvm_round_d:
return ConstantFoldFP(
round, APF, Ty,
nvvm::GetNVVMDenromMode(
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));

case Intrinsic::nvvm_saturate_ftz_f:
case Intrinsic::nvvm_saturate_d:
case Intrinsic::nvvm_saturate_f: {
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
if (V.isNegative() || V.isZero() || V.isNaN())
return ConstantFP::getZero(Ty);
APFloat One = APFloat::getOne(APF.getSemantics());
if (V > One)
return ConstantFP::get(Ty->getContext(), One);
return ConstantFP::get(Ty->getContext(), APF);
}

case Intrinsic::nvvm_sqrt_rn_ftz_f:
case Intrinsic::nvvm_sqrt_f:
case Intrinsic::nvvm_sqrt_rn_d:
case Intrinsic::nvvm_sqrt_rn_f:
if (APF.isNegative())
return nullptr;
return ConstantFoldFP(
sqrt, APF, Ty,
nvvm::GetNVVMDenromMode(
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));

// AMDGCN Intrinsics:
case Intrinsic::amdgcn_cos:
case Intrinsic::amdgcn_sin: {
double V = getValueAsDouble(Op);
Expand Down
Loading
Loading