Skip to content

Commit 630fd78

Browse files
arsenm authored and github-actions[bot] committed
Automerge: AMDGPU: Improve exp10 lowering for f16 (#170771)
2 parents 3596d92 + 63e9d60 commit 630fd78

File tree

3 files changed

+72
-179
lines changed

3 files changed

+72
-179
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,19 +2961,28 @@ SDValue AMDGPUTargetLowering::lowerFEXP2(SDValue Op, SelectionDAG &DAG) const {
29612961
return DAG.getNode(ISD::FMUL, SL, VT, Exp2, ResultScale, Flags);
29622962
}
29632963

2964+
// Shared lowering of exp / exp10 to a single multiply feeding an exp2 node.
//
// X is multiplied by log2(e) (or by log2(10) when IsExp10 is set) and the
// product is passed to AMDGPUISD::EXP when the value type is f32, or to the
// generic ISD::FEXP2 node for any other type. Flags is attached to both the
// multiply and the exp2 node. This helper does no input scaling or range
// adjustment of its own.
SDValue AMDGPUTargetLowering::lowerFEXPUnsafeImpl(SDValue X, const SDLoc &SL,
2965+
SelectionDAG &DAG,
2966+
SDNodeFlags Flags,
2967+
bool IsExp10) const {
2968+
// exp(x) -> exp2(log2(e) * x);
2969+
// exp10(x) -> exp2(log2(10) * x);
2970+
EVT VT = X.getValueType();
2971+
// 0x1.a934f0p+1f is log2(10) spelled as a hexadecimal float literal.
SDValue Const =
2972+
DAG.getConstantFP(IsExp10 ? 0x1.a934f0p+1f : numbers::log2e, SL, VT);
2973+
2974+
SDValue Mul = DAG.getNode(ISD::FMUL, SL, VT, X, Const, Flags);
2975+
// f32 uses the target-specific EXP node; other types use generic FEXP2.
return DAG.getNode(VT == MVT::f32 ? (unsigned)AMDGPUISD::EXP
2976+
: (unsigned)ISD::FEXP2,
2977+
SL, VT, Mul, Flags);
2978+
}
2979+
29642980
SDValue AMDGPUTargetLowering::lowerFEXPUnsafe(SDValue X, const SDLoc &SL,
29652981
SelectionDAG &DAG,
29662982
SDNodeFlags Flags) const {
29672983
EVT VT = X.getValueType();
2968-
const SDValue Log2E = DAG.getConstantFP(numbers::log2e, SL, VT);
2969-
2970-
if (VT != MVT::f32 || !needsDenormHandlingF32(DAG, X, Flags)) {
2971-
// exp2(M_LOG2E_F * f);
2972-
SDValue Mul = DAG.getNode(ISD::FMUL, SL, VT, X, Log2E, Flags);
2973-
return DAG.getNode(VT == MVT::f32 ? (unsigned)AMDGPUISD::EXP
2974-
: (unsigned)ISD::FEXP2,
2975-
SL, VT, Mul, Flags);
2976-
}
2984+
if (VT != MVT::f32 || !needsDenormHandlingF32(DAG, X, Flags))
2985+
return lowerFEXPUnsafeImpl(X, SL, DAG, Flags, /*IsExp10=*/false);
29772986

29782987
EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
29792988

@@ -2987,6 +2996,7 @@ SDValue AMDGPUTargetLowering::lowerFEXPUnsafe(SDValue X, const SDLoc &SL,
29872996
SDValue AdjustedX =
29882997
DAG.getNode(ISD::SELECT, SL, VT, NeedsScaling, ScaledX, X);
29892998

2999+
const SDValue Log2E = DAG.getConstantFP(numbers::log2e, SL, VT);
29903000
SDValue ExpInput = DAG.getNode(ISD::FMUL, SL, VT, AdjustedX, Log2E, Flags);
29913001

29923002
SDValue Exp2 = DAG.getNode(AMDGPUISD::EXP, SL, VT, ExpInput, Flags);
@@ -3005,6 +3015,7 @@ SDValue AMDGPUTargetLowering::lowerFEXP10Unsafe(SDValue X, const SDLoc &SL,
30053015
SelectionDAG &DAG,
30063016
SDNodeFlags Flags) const {
30073017
const EVT VT = X.getValueType();
3018+
30083019
const unsigned Exp2Op = VT == MVT::f32 ? static_cast<unsigned>(AMDGPUISD::EXP)
30093020
: static_cast<unsigned>(ISD::FEXP2);
30103021

@@ -3072,13 +3083,15 @@ SDValue AMDGPUTargetLowering::lowerFEXP(SDValue Op, SelectionDAG &DAG) const {
30723083
if (VT.isVector())
30733084
return SDValue();
30743085

3086+
// Nothing in half is a denormal when promoted to f32.
3087+
//
30753088
// exp(f16 x) ->
30763089
// fptrunc (v_exp_f32 (fmul (fpext x), log2e))
3077-
3078-
// Nothing in half is a denormal when promoted to f32.
3090+
//
3091+
// exp10(f16 x) ->
3092+
// fptrunc (v_exp_f32 (fmul (fpext x), log2(10)))
30793093
SDValue Ext = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, X, Flags);
3080-
SDValue Lowered = IsExp10 ? lowerFEXP10Unsafe(Ext, SL, DAG, Flags)
3081-
: lowerFEXPUnsafe(Ext, SL, DAG, Flags);
3094+
SDValue Lowered = lowerFEXPUnsafeImpl(Ext, SL, DAG, Flags, IsExp10);
30823095
return DAG.getNode(ISD::FP_ROUND, SL, VT, Lowered,
30833096
DAG.getTargetConstant(0, SL, MVT::i32), Flags);
30843097
}

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class AMDGPUTargetLowering : public TargetLowering {
7878
bool IsLog10, SDNodeFlags Flags) const;
7979
SDValue lowerFEXP2(SDValue Op, SelectionDAG &DAG) const;
8080

81+
SDValue lowerFEXPUnsafeImpl(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
82+
SDNodeFlags Flags, bool IsExp10) const;
83+
8184
SDValue lowerFEXPUnsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
8285
SDNodeFlags Flags) const;
8386
SDValue lowerFEXP10Unsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,

0 commit comments

Comments
 (0)