@@ -219,6 +219,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
219219 setOperationAction(ISD::SELECT, MVT::f64, Promote);
220220 AddPromotedToType(ISD::SELECT, MVT::f64, MVT::i64);
221221
222+ setOperationAction(ISD::FSQRT, MVT::f64, Custom);
223+
222224 setOperationAction(ISD::SELECT_CC,
223225 {MVT::f32, MVT::i32, MVT::i64, MVT::f64, MVT::i1}, Expand);
224226
@@ -4924,7 +4926,10 @@ SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
49244926 "Load should return a value and a chain");
49254927 return Result;
49264928 }
4927-
4929+ case ISD::FSQRT:
4930+ if (Op.getValueType() == MVT::f64)
4931+ return lowerFSQRTF64(Op, DAG);
4932+ return SDValue();
49284933 case ISD::FSIN:
49294934 case ISD::FCOS:
49304935 return LowerTrig(Op, DAG);
@@ -9749,6 +9754,87 @@ SDValue SITargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
97499754 return SDValue();
97509755}
97519756
9757+ SDValue SITargetLowering::lowerFSQRTF64(SDValue Op, SelectionDAG &DAG) const {
9758+ // For double type, the SQRT and RSQ instructions don't have required
9759+ // precision, we apply Goldschmidt's algorithm to improve the result:
9760+ //
9761+ // y0 = rsq(x)
9762+ // g0 = x * y0
9763+ // h0 = 0.5 * y0
9764+ //
9765+ // r0 = 0.5 - h0 * g0
9766+ // g1 = g0 * r0 + g0
9767+ // h1 = h0 * r0 + h0
9768+ //
9769+ // r1 = 0.5 - h1 * g1 => d0 = x - g1 * g1
9770+ // g2 = g1 * r1 + g1 g2 = d0 * h1 + g1
9771+ // h2 = h1 * r1 + h1
9772+ //
9773+ // r2 = 0.5 - h2 * g2 => d1 = x - g2 * g2
9774+ // g3 = g2 * r2 + g2 g3 = d1 * h1 + g2
9775+ //
9776+ // sqrt(x) = g3
9777+
9778+ SDNodeFlags Flags = Op->getFlags();
9779+
9780+ SDLoc DL(Op);
9781+
9782+ SDValue X = Op.getOperand(0);
9783+ SDValue ScaleConstant = DAG.getConstantFP(0x1.0p-767, DL, MVT::f64);
9784+
9785+ SDValue Scaling = DAG.getSetCC(DL, MVT::i1, X, ScaleConstant, ISD::SETOLT);
9786+
9787+ SDValue ZeroInt = DAG.getConstant(0, DL, MVT::i32);
9788+
9789+ // Scale up input if it is too small.
9790+ SDValue ScaleUpFactor = DAG.getConstant(256, DL, MVT::i32);
9791+ SDValue ScaleUp =
9792+ DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleUpFactor, ZeroInt);
9793+ SDValue SqrtX = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, X, ScaleUp, Flags);
9794+
9795+ SDValue SqrtY = DAG.getNode(AMDGPUISD::RSQ, DL, MVT::f64, SqrtX);
9796+
9797+ SDValue SqrtS0 = DAG.getNode(ISD::FMUL, DL, MVT::f64, SqrtX, SqrtY);
9798+
9799+ SDValue Half = DAG.getConstantFP(0.5, DL, MVT::f64);
9800+ SDValue SqrtH0 = DAG.getNode(ISD::FMUL, DL, MVT::f64, SqrtY, Half);
9801+
9802+ SDValue NegSqrtH0 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtH0);
9803+ SDValue SqrtR0 = DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtH0, SqrtS0, Half);
9804+
9805+ SDValue SqrtH1 = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtH0, SqrtR0, SqrtH0);
9806+
9807+ SDValue SqrtS1 = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtS0, SqrtR0, SqrtS0);
9808+
9809+ SDValue NegSqrtS1 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtS1);
9810+ SDValue SqrtD0 = DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtS1, SqrtS1, SqrtX);
9811+
9812+ SDValue SqrtS2 = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD0, SqrtH1, SqrtS1);
9813+
9814+ SDValue NegSqrtS2 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtS2);
9815+ SDValue SqrtD1 =
9816+ DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtS2, SqrtS2, SqrtX);
9817+
9818+ SDValue SqrtRet = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD1, SqrtH1, SqrtS2);
9819+
9820+ SDValue ScaleDownFactor = DAG.getConstant(-128, DL, MVT::i32);
9821+ SDValue ScaleDown =
9822+ DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleDownFactor, ZeroInt);
9823+ SqrtRet = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, SqrtRet, ScaleDown, Flags);
9824+
9825+ // TODO: Switch to fcmp oeq 0 for finite only. Can't fully remove this check
9826+ // with finite only or nsz because rsq(+/-0) = +/-inf
9827+
9828+ // TODO: Check for DAZ and expand to subnormals
9829+ SDValue IsZeroOrInf =
9830+ DAG.getNode(ISD::IS_FPCLASS, DL, MVT::i1, SqrtX,
9831+ DAG.getTargetConstant(fcZero | fcPosInf, DL, MVT::i32));
9832+
9833+ // If x is +INF, +0, or -0, use its original value
9834+ return DAG.getNode(ISD::SELECT, DL, MVT::f64, IsZeroOrInf, SqrtX, SqrtRet,
9835+ Flags);
9836+ }
9837+
97529838SDValue SITargetLowering::LowerTrig(SDValue Op, SelectionDAG &DAG) const {
97539839 SDLoc DL(Op);
97549840 EVT VT = Op.getValueType();
0 commit comments