Skip to content

Commit c887ade

Browse files
committed
[TargetLowering] Consider fast-math flags in getSqrtInputTest
1 parent a611074 commit c887ade

File tree

7 files changed

+28
-14
lines changed

7 files changed

+28
-14
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5356,7 +5356,8 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
53565356
/// comparison may check if the operand is NAN, INF, zero, normal, etc. The
53575357
/// result should be used as the condition operand for a select or branch.
53585358
virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
5359-
const DenormalMode &Mode) const;
5359+
const DenormalMode &Mode,
5360+
SDNodeFlags Flags) const;
53605361

53615362
/// Return a target-dependent result if the input operand is not suitable for
53625363
/// use with a square root estimate calculation.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29832,7 +29832,8 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
2983229832
if (!Reciprocal) {
2983329833
SDLoc DL(Op);
2983429834
// Try the target specific test first.
29835-
SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
29835+
SDValue Test =
29836+
TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT), Flags);
2983629837

2983729838
// The estimate is now completely wrong if the input was exactly 0.0 or
2983829839
// possibly a denormal. Force the answer to 0.0 or value provided by

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7451,7 +7451,8 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
74517451
}
74527452

74537453
SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
7454-
const DenormalMode &Mode) const {
7454+
const DenormalMode &Mode,
7455+
SDNodeFlags Flags) const {
74557456
SDLoc DL(Op);
74567457
EVT VT = Op.getValueType();
74577458
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
@@ -7462,7 +7463,10 @@ SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
74627463
if (Mode.Input == DenormalMode::PreserveSign ||
74637464
Mode.Input == DenormalMode::PositiveZero) {
74647465
// Test = X == 0.0
7465-
return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
7466+
SDValue Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
7467+
// Propagate fast-math flags from fcmp.
7468+
Test->setFlags(Flags);
7469+
return Test;
74667470
}
74677471

74687472
// Testing it with denormal inputs to avoid wrong estimate.
@@ -7471,8 +7475,11 @@ SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
74717475
const fltSemantics &FltSem = VT.getFltSemantics();
74727476
APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
74737477
SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
7474-
SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
7475-
return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
7478+
SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op, Flags);
7479+
SDValue Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
7480+
// Propagate fast-math flags from fcmp.
7481+
Test->setFlags(Flags);
7482+
return Test;
74767483
}
74777484

74787485
SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12701,13 +12701,15 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
1270112701
return SDValue();
1270212702
}
1270312703

12704-
SDValue
12705-
AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
12706-
const DenormalMode &Mode) const {
12704+
/// Build the comparison used as the square-root input test on AArch64.
///
/// This override always emits `Op == 0.0` (ISD::SETEQ against an FP zero of
/// the operand's type) and attaches the supplied fast-math flags to it.
///
/// \param Op    Floating-point operand of the square root.
/// \param DAG   SelectionDAG in which the nodes are created.
/// \param Mode  Denormal mode; not consulted by this override.
/// \param Flags Fast-math flags to propagate onto the comparison node.
/// \returns The SETCC node carrying \p Flags.
SDValue AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
                                                const DenormalMode &Mode,
                                                SDNodeFlags Flags) const {
  SDLoc DL(Op);
  EVT VT = Op.getValueType();
  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
  // Test = X == 0.0
  SDValue Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
  // Propagate fast-math flags onto the comparison and return that same node.
  // Previously a second, identical getSetCC call was returned instead of
  // `Test`, so the flagged node was not the one handed back to the caller.
  // NOTE(review): if SelectionDAG uniquifies the two SETCC requests, the
  // second call may also have intersected the flags away -- confirm.
  Test->setFlags(Flags);
  return Test;
}
1271312715

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,8 @@ class AArch64TargetLowering : public TargetLowering {
799799
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
800800
int &ExtraSteps) const override;
801801
SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
802-
const DenormalMode &Mode) const override;
802+
const DenormalMode &Mode,
803+
SDNodeFlags Flags) const override;
803804
SDValue getSqrtResultForDenormInput(SDValue Operand,
804805
SelectionDAG &DAG) const override;
805806
unsigned combineRepeatedFPDivisors() const override;

llvm/lib/Target/PowerPC/PPCISelLowering.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14650,17 +14650,18 @@ static int getEstimateRefinementSteps(EVT VT, const PPCSubtarget &Subtarget) {
1465014650
}
1465114651

1465214652
SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
14653-
const DenormalMode &Mode) const {
14653+
const DenormalMode &Mode,
14654+
SDNodeFlags Flags) const {
1465414655
// We only have VSX Vector Test for software Square Root.
1465514656
EVT VT = Op.getValueType();
1465614657
if (!isTypeLegal(MVT::i1) ||
1465714658
(VT != MVT::f64 &&
1465814659
((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX())))
14659-
return TargetLowering::getSqrtInputTest(Op, DAG, Mode);
14660+
return TargetLowering::getSqrtInputTest(Op, DAG, Mode, Flags);
1466014661

1466114662
SDLoc DL(Op);
1466214663
// The output register of FTSQRT is CR field.
14663-
SDValue FTSQRT = DAG.getNode(PPCISD::FTSQRT, DL, MVT::i32, Op);
14664+
SDValue FTSQRT = DAG.getNode(PPCISD::FTSQRT, DL, MVT::i32, Op, Flags);
1466414665
// ftsqrt BF,FRB
1466514666
// Let e_b be the unbiased exponent of the double-precision
1466614667
// floating-point operand in register FRB.

llvm/lib/Target/PowerPC/PPCISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,8 @@ namespace llvm {
14631463
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
14641464
int &RefinementSteps) const override;
14651465
SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
1466-
const DenormalMode &Mode) const override;
1466+
const DenormalMode &Mode,
1467+
SDNodeFlags Flags) const override;
14671468
SDValue getSqrtResultForDenormInput(SDValue Operand,
14681469
SelectionDAG &DAG) const override;
14691470
unsigned combineRepeatedFPDivisors() const override;

0 commit comments

Comments
 (0)