Skip to content

Commit 216fdc8

Browse files
arsenmtru
authored andcommitted
AMDGPU: Fix fast f32 log/log10
OpenCL conformance didn't like interpreting afn as ignore the denormal handling. https://reviews.llvm.org/D157940 (cherry picked from commit 4b7b4b9)
1 parent 1ba6abc commit 216fdc8

File tree

6 files changed

+1310
-170
lines changed

6 files changed

+1310
-170
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,9 +2609,7 @@ SDValue AMDGPUTargetLowering::LowerFLOGCommon(SDValue Op,
26092609
X = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, X, Flags);
26102610
}
26112611

2612-
SDValue Lowered = LowerFLOGUnsafe(
2613-
X, DL, DAG, IsLog10 ? numbers::ln2 / numbers::ln10 : numbers::ln2,
2614-
Flags);
2612+
SDValue Lowered = LowerFLOGUnsafe(X, DL, DAG, IsLog10, Flags);
26152613
if (VT == MVT::f16 && !Subtarget->has16BitInsts()) {
26162614
return DAG.getNode(ISD::FP_ROUND, DL, VT, Lowered,
26172615
DAG.getTargetConstant(0, DL, MVT::i32), Flags);
@@ -2696,11 +2694,36 @@ SDValue AMDGPUTargetLowering::LowerFLOG10(SDValue Op, SelectionDAG &DAG) const {
26962694
// Do f32 fast math expansion for flog2 or flog10. This is accurate enough for a
26972695
// promote f16 operation.
26982696
SDValue AMDGPUTargetLowering::LowerFLOGUnsafe(SDValue Src, const SDLoc &SL,
2699-
SelectionDAG &DAG,
2700-
double Log2BaseInverted,
2697+
SelectionDAG &DAG, bool IsLog10,
27012698
SDNodeFlags Flags) const {
27022699
EVT VT = Src.getValueType();
27032700
unsigned LogOp = VT == MVT::f32 ? AMDGPUISD::LOG : ISD::FLOG2;
2701+
2702+
double Log2BaseInverted =
2703+
IsLog10 ? numbers::ln2 / numbers::ln10 : numbers::ln2;
2704+
2705+
if (VT == MVT::f32) {
2706+
auto [ScaledInput, IsScaled] = getScaledLogInput(DAG, SL, Src, Flags);
2707+
if (ScaledInput) {
2708+
SDValue LogSrc = DAG.getNode(AMDGPUISD::LOG, SL, VT, ScaledInput, Flags);
2709+
SDValue ScaledResultOffset =
2710+
DAG.getConstantFP(-32.0 * Log2BaseInverted, SL, VT);
2711+
2712+
SDValue Zero = DAG.getConstantFP(0.0f, SL, VT);
2713+
2714+
SDValue ResultOffset = DAG.getNode(ISD::SELECT, SL, VT, IsScaled,
2715+
ScaledResultOffset, Zero, Flags);
2716+
2717+
SDValue Log2Inv = DAG.getConstantFP(Log2BaseInverted, SL, VT);
2718+
2719+
if (Subtarget->hasFastFMAF32())
2720+
return DAG.getNode(ISD::FMA, SL, VT, LogSrc, Log2Inv, ResultOffset,
2721+
Flags);
2722+
SDValue Mul = DAG.getNode(ISD::FMUL, SL, VT, LogSrc, Log2Inv, Flags);
2723+
return DAG.getNode(ISD::FADD, SL, VT, Mul, ResultOffset);
2724+
}
2725+
}
2726+
27042727
SDValue Log2Operand = DAG.getNode(LogOp, SL, VT, Src, Flags);
27052728
SDValue Log2BaseInvertedOperand = DAG.getConstantFP(Log2BaseInverted, SL, VT);
27062729

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class AMDGPUTargetLowering : public TargetLowering {
7272
SDValue LowerFLOGCommon(SDValue Op, SelectionDAG &DAG) const;
7373
SDValue LowerFLOG10(SDValue Op, SelectionDAG &DAG) const;
7474
SDValue LowerFLOGUnsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
75-
double Log2BaseInverted, SDNodeFlags Flags) const;
75+
bool IsLog10, SDNodeFlags Flags) const;
7676
SDValue lowerFEXP2(SDValue Op, SelectionDAG &DAG) const;
7777

7878
SDValue lowerFEXPUnsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3131,16 +3131,13 @@ bool AMDGPULegalizerInfo::legalizeFlogCommon(MachineInstr &MI,
31313131

31323132
if (Ty == F16 || MI.getFlag(MachineInstr::FmAfn) ||
31333133
TM.Options.ApproxFuncFPMath || TM.Options.UnsafeFPMath) {
3134-
const double Log2BaseInv =
3135-
IsLog10 ? numbers::ln2 / numbers::ln10 : numbers::ln2;
3136-
31373134
if (Ty == F16 && !ST.has16BitInsts()) {
31383135
Register LogVal = MRI.createGenericVirtualRegister(F32);
31393136
auto PromoteSrc = B.buildFPExt(F32, X);
3140-
legalizeFlogUnsafe(B, LogVal, PromoteSrc.getReg(0), Log2BaseInv, Flags);
3137+
legalizeFlogUnsafe(B, LogVal, PromoteSrc.getReg(0), IsLog10, Flags);
31413138
B.buildFPTrunc(Dst, LogVal);
31423139
} else {
3143-
legalizeFlogUnsafe(B, Dst, X, Log2BaseInv, Flags);
3140+
legalizeFlogUnsafe(B, Dst, X, IsLog10, Flags);
31443141
}
31453142

31463143
MI.eraseFromParent();
@@ -3224,10 +3221,36 @@ bool AMDGPULegalizerInfo::legalizeFlogCommon(MachineInstr &MI,
32243221
}
32253222

32263223
bool AMDGPULegalizerInfo::legalizeFlogUnsafe(MachineIRBuilder &B, Register Dst,
3227-
Register Src,
3228-
double Log2BaseInverted,
3224+
Register Src, bool IsLog10,
32293225
unsigned Flags) const {
3226+
const double Log2BaseInverted =
3227+
IsLog10 ? numbers::ln2 / numbers::ln10 : numbers::ln2;
3228+
32303229
LLT Ty = B.getMRI()->getType(Dst);
3230+
3231+
if (Ty == LLT::scalar(32)) {
3232+
auto [ScaledInput, IsScaled] = getScaledLogInput(B, Src, Flags);
3233+
if (ScaledInput) {
3234+
auto LogSrc = B.buildIntrinsic(Intrinsic::amdgcn_log, {Ty}, false)
3235+
.addUse(Src)
3236+
.setMIFlags(Flags);
3237+
auto ScaledResultOffset = B.buildFConstant(Ty, -32.0 * Log2BaseInverted);
3238+
auto Zero = B.buildFConstant(Ty, 0.0);
3239+
auto ResultOffset =
3240+
B.buildSelect(Ty, IsScaled, ScaledResultOffset, Zero, Flags);
3241+
auto Log2Inv = B.buildFConstant(Ty, Log2BaseInverted);
3242+
3243+
if (ST.hasFastFMAF32())
3244+
B.buildFMA(Dst, LogSrc, Log2Inv, ResultOffset, Flags);
3245+
else {
3246+
auto Mul = B.buildFMul(Ty, LogSrc, Log2Inv, Flags);
3247+
B.buildFAdd(Dst, Mul, ResultOffset, Flags);
3248+
}
3249+
3250+
return true;
3251+
}
3252+
}
3253+
32313254
auto Log2Operand = Ty == LLT::scalar(16)
32323255
? B.buildFLog2(Ty, Src, Flags)
32333256
: B.buildIntrinsic(Intrinsic::amdgcn_log, {Ty}, false)

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
8585
bool legalizeFlog2(MachineInstr &MI, MachineIRBuilder &B) const;
8686
bool legalizeFlogCommon(MachineInstr &MI, MachineIRBuilder &B) const;
8787
bool legalizeFlogUnsafe(MachineIRBuilder &B, Register Dst, Register Src,
88-
double Log2BaseInverted, unsigned Flags) const;
88+
bool IsLog10, unsigned Flags) const;
8989
bool legalizeFExp2(MachineInstr &MI, MachineIRBuilder &B) const;
9090
bool legalizeFExpUnsafe(MachineIRBuilder &B, Register Dst, Register Src,
9191
unsigned Flags) const;

0 commit comments

Comments
 (0)