Skip to content

Conversation

paperchalice
Copy link
Contributor

Propagate fast-math flags when lowering fsqrt; this is required by PowerPC.

@paperchalice paperchalice marked this pull request as ready for review October 12, 2025 12:42
@paperchalice paperchalice added the floating-point Floating-point math label Oct 12, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2025

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-backend-powerpc

Author: None (paperchalice)

Changes

Propagate fast-math flags when lowering fsqrt; this is required by PowerPC.


Full diff: https://github.com/llvm/llvm-project/pull/163061.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+2-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+11-4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+5-3)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2-1)
  • (modified) llvm/lib/Target/PowerPC/PPCISelLowering.cpp (+4-3)
  • (modified) llvm/lib/Target/PowerPC/PPCISelLowering.h (+2-1)
  • (modified) llvm/test/CodeGen/PowerPC/fmf-propagation.ll (+40-50)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 73f2c55a71125..9c9f22f6febfa 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5356,7 +5356,8 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// comparison may check if the operand is NAN, INF, zero, normal, etc. The
   /// result should be used as the condition operand for a select or branch.
   virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
-                                   const DenormalMode &Mode) const;
+                                   const DenormalMode &Mode,
+                                   SDNodeFlags Flags) const;
 
   /// Return a target-dependent result if the input operand is not suitable for
   /// use with a square root estimate calculation.
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b1accdd066dfd..af725923b4597 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -29832,7 +29832,8 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
     if (!Reciprocal) {
       SDLoc DL(Op);
       // Try the target specific test first.
-      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
+      SDValue Test =
+          TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT), Flags);
 
       // The estimate is now completely wrong if the input was exactly 0.0 or
       // possibly a denormal. Force the answer to 0.0 or value provided by
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cc503d324e74b..780d8a090ba18 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -7451,7 +7451,8 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
 }
 
 SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
-                                         const DenormalMode &Mode) const {
+                                         const DenormalMode &Mode,
+                                         SDNodeFlags Flags) const {
   SDLoc DL(Op);
   EVT VT = Op.getValueType();
   EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
@@ -7462,7 +7463,10 @@ SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
   if (Mode.Input == DenormalMode::PreserveSign ||
       Mode.Input == DenormalMode::PositiveZero) {
     // Test = X == 0.0
-    return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+    SDValue Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+    // Propagate fast-math flags from fcmp.
+    Test->setFlags(Flags);
+    return Test;
   }
 
   // Testing it with denormal inputs to avoid wrong estimate.
@@ -7471,8 +7475,11 @@ SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
   const fltSemantics &FltSem = VT.getFltSemantics();
   APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
   SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
-  SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
-  return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
+  SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op, Flags);
+  SDValue Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
+  // Propagate fast-math flags from fcmp.
+  Test->setFlags(Flags);
+  return Test;
 }
 
 SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbce3b0efbb7c..17aeee0f796c1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12701,13 +12701,15 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
   return SDValue();
 }
 
-SDValue
-AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
-                                        const DenormalMode &Mode) const {
+SDValue AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
+                                                const DenormalMode &Mode,
+                                                SDNodeFlags Flags) const {
   SDLoc DL(Op);
   EVT VT = Op.getValueType();
   EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
   SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-  return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+  SDValue Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+  Test->setFlags(Flags);
+  return Test;
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 00956fdc8e48e..89177b7962e41 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -799,7 +799,8 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                            int &ExtraSteps) const override;
   SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
-                           const DenormalMode &Mode) const override;
+                           const DenormalMode &Mode,
+                           SDNodeFlags Flags) const override;
   SDValue getSqrtResultForDenormInput(SDValue Operand,
                                       SelectionDAG &DAG) const override;
   unsigned combineRepeatedFPDivisors() const override;
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index 944a1e2e6fa17..5a8f4012ec70f 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -14650,17 +14650,18 @@ static int getEstimateRefinementSteps(EVT VT, const PPCSubtarget &Subtarget) {
 }
 
 SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
-                                            const DenormalMode &Mode) const {
+                                            const DenormalMode &Mode,
+                                            SDNodeFlags Flags) const {
   // We only have VSX Vector Test for software Square Root.
   EVT VT = Op.getValueType();
   if (!isTypeLegal(MVT::i1) ||
       (VT != MVT::f64 &&
        ((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX())))
-    return TargetLowering::getSqrtInputTest(Op, DAG, Mode);
+    return TargetLowering::getSqrtInputTest(Op, DAG, Mode, Flags);
 
   SDLoc DL(Op);
   // The output register of FTSQRT is CR field.
-  SDValue FTSQRT = DAG.getNode(PPCISD::FTSQRT, DL, MVT::i32, Op);
+  SDValue FTSQRT = DAG.getNode(PPCISD::FTSQRT, DL, MVT::i32, Op, Flags);
   // ftsqrt BF,FRB
   // Let e_b be the unbiased exponent of the double-precision
   // floating-point operand in register FRB.
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.h b/llvm/lib/Target/PowerPC/PPCISelLowering.h
index 59f338782ba43..4daef49cdf030 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.h
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.h
@@ -1463,7 +1463,8 @@ namespace llvm {
     SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                              int &RefinementSteps) const override;
     SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
-                             const DenormalMode &Mode) const override;
+                             const DenormalMode &Mode,
+                             SDNodeFlags Flags) const override;
     SDValue getSqrtResultForDenormInput(SDValue Operand,
                                         SelectionDAG &DAG) const override;
     unsigned combineRepeatedFPDivisors() const override;
diff --git a/llvm/test/CodeGen/PowerPC/fmf-propagation.ll b/llvm/test/CodeGen/PowerPC/fmf-propagation.ll
index e71f59c79ce4d..cad684ecf0170 100644
--- a/llvm/test/CodeGen/PowerPC/fmf-propagation.ll
+++ b/llvm/test/CodeGen/PowerPC/fmf-propagation.ll
@@ -325,24 +325,21 @@ define float @sqrt_afn_ieee(float %x) #0 {
 ;
 ; GLOBAL-LABEL: sqrt_afn_ieee:
 ; GLOBAL:       # %bb.0:
-; GLOBAL-NEXT:    addis 3, 2, .LCPI11_1@toc@ha
-; GLOBAL-NEXT:    xsabsdp 0, 1
-; GLOBAL-NEXT:    lfs 2, .LCPI11_1@toc@l(3)
-; GLOBAL-NEXT:    fcmpu 0, 0, 2
-; GLOBAL-NEXT:    xxlxor 0, 0, 0
-; GLOBAL-NEXT:    blt 0, .LBB11_2
-; GLOBAL-NEXT:  # %bb.1:
 ; GLOBAL-NEXT:    xsrsqrtesp 0, 1
 ; GLOBAL-NEXT:    vspltisw 2, -3
 ; GLOBAL-NEXT:    addis 3, 2, .LCPI11_0@toc@ha
-; GLOBAL-NEXT:    xvcvsxwdp 2, 34
-; GLOBAL-NEXT:    xsmulsp 1, 1, 0
-; GLOBAL-NEXT:    xsmaddasp 2, 1, 0
+; GLOBAL-NEXT:    xvcvsxwdp 3, 34
+; GLOBAL-NEXT:    xsmulsp 2, 1, 0
+; GLOBAL-NEXT:    xsabsdp 1, 1
+; GLOBAL-NEXT:    xsmaddasp 3, 2, 0
 ; GLOBAL-NEXT:    lfs 0, .LCPI11_0@toc@l(3)
-; GLOBAL-NEXT:    xsmulsp 0, 1, 0
-; GLOBAL-NEXT:    xsmulsp 0, 0, 2
-; GLOBAL-NEXT:  .LBB11_2:
-; GLOBAL-NEXT:    fmr 1, 0
+; GLOBAL-NEXT:    addis 3, 2, .LCPI11_1@toc@ha
+; GLOBAL-NEXT:    xsmulsp 0, 2, 0
+; GLOBAL-NEXT:    lfs 2, .LCPI11_1@toc@l(3)
+; GLOBAL-NEXT:    xssubsp 1, 1, 2
+; GLOBAL-NEXT:    xxlxor 2, 2, 2
+; GLOBAL-NEXT:    xsmulsp 0, 0, 3
+; GLOBAL-NEXT:    fsel 1, 1, 0, 2
 ; GLOBAL-NEXT:    blr
   %rt = call afn ninf float @llvm.sqrt.f32(float %x)
   ret float %rt
@@ -393,21 +390,19 @@ define float @sqrt_afn_preserve_sign(float %x) #1 {
 ;
 ; GLOBAL-LABEL: sqrt_afn_preserve_sign:
 ; GLOBAL:       # %bb.0:
-; GLOBAL-NEXT:    xxlxor 0, 0, 0
-; GLOBAL-NEXT:    fcmpu 0, 1, 0
-; GLOBAL-NEXT:    beq 0, .LBB13_2
-; GLOBAL-NEXT:  # %bb.1:
 ; GLOBAL-NEXT:    xsrsqrtesp 0, 1
 ; GLOBAL-NEXT:    vspltisw 2, -3
 ; GLOBAL-NEXT:    addis 3, 2, .LCPI13_0@toc@ha
-; GLOBAL-NEXT:    xvcvsxwdp 2, 34
-; GLOBAL-NEXT:    xsmulsp 1, 1, 0
-; GLOBAL-NEXT:    xsmaddasp 2, 1, 0
+; GLOBAL-NEXT:    xvcvsxwdp 3, 34
+; GLOBAL-NEXT:    xsmulsp 2, 1, 0
+; GLOBAL-NEXT:    xsmaddasp 3, 2, 0
 ; GLOBAL-NEXT:    lfs 0, .LCPI13_0@toc@l(3)
-; GLOBAL-NEXT:    xsmulsp 0, 1, 0
-; GLOBAL-NEXT:    xsmulsp 0, 0, 2
-; GLOBAL-NEXT:  .LBB13_2:
-; GLOBAL-NEXT:    fmr 1, 0
+; GLOBAL-NEXT:    xsmulsp 0, 2, 0
+; GLOBAL-NEXT:    xxlxor 2, 2, 2
+; GLOBAL-NEXT:    xsmulsp 0, 0, 3
+; GLOBAL-NEXT:    fsel 2, 1, 2, 0
+; GLOBAL-NEXT:    xsnegdp 1, 1
+; GLOBAL-NEXT:    fsel 1, 1, 2, 0
 ; GLOBAL-NEXT:    blr
   %rt = call afn ninf float @llvm.sqrt.f32(float %x)
   ret float %rt
@@ -462,24 +457,21 @@ define float @sqrt_fast_ieee(float %x) #0 {
 ;
 ; GLOBAL-LABEL: sqrt_fast_ieee:
 ; GLOBAL:       # %bb.0:
-; GLOBAL-NEXT:    addis 3, 2, .LCPI15_1@toc@ha
-; GLOBAL-NEXT:    xsabsdp 0, 1
-; GLOBAL-NEXT:    lfs 2, .LCPI15_1@toc@l(3)
-; GLOBAL-NEXT:    fcmpu 0, 0, 2
-; GLOBAL-NEXT:    xxlxor 0, 0, 0
-; GLOBAL-NEXT:    blt 0, .LBB15_2
-; GLOBAL-NEXT:  # %bb.1:
 ; GLOBAL-NEXT:    xsrsqrtesp 0, 1
 ; GLOBAL-NEXT:    vspltisw 2, -3
 ; GLOBAL-NEXT:    addis 3, 2, .LCPI15_0@toc@ha
-; GLOBAL-NEXT:    xvcvsxwdp 2, 34
-; GLOBAL-NEXT:    xsmulsp 1, 1, 0
-; GLOBAL-NEXT:    xsmaddasp 2, 1, 0
+; GLOBAL-NEXT:    xvcvsxwdp 3, 34
+; GLOBAL-NEXT:    xsmulsp 2, 1, 0
+; GLOBAL-NEXT:    xsabsdp 1, 1
+; GLOBAL-NEXT:    xsmaddasp 3, 2, 0
 ; GLOBAL-NEXT:    lfs 0, .LCPI15_0@toc@l(3)
-; GLOBAL-NEXT:    xsmulsp 0, 1, 0
-; GLOBAL-NEXT:    xsmulsp 0, 0, 2
-; GLOBAL-NEXT:  .LBB15_2:
-; GLOBAL-NEXT:    fmr 1, 0
+; GLOBAL-NEXT:    addis 3, 2, .LCPI15_1@toc@ha
+; GLOBAL-NEXT:    xsmulsp 0, 2, 0
+; GLOBAL-NEXT:    lfs 2, .LCPI15_1@toc@l(3)
+; GLOBAL-NEXT:    xssubsp 1, 1, 2
+; GLOBAL-NEXT:    xxlxor 2, 2, 2
+; GLOBAL-NEXT:    xsmulsp 0, 0, 3
+; GLOBAL-NEXT:    fsel 1, 1, 0, 2
 ; GLOBAL-NEXT:    blr
   %rt = call contract reassoc afn ninf float @llvm.sqrt.f32(float %x)
   ret float %rt
@@ -517,21 +509,19 @@ define float @sqrt_fast_preserve_sign(float %x) #1 {
 ;
 ; GLOBAL-LABEL: sqrt_fast_preserve_sign:
 ; GLOBAL:       # %bb.0:
-; GLOBAL-NEXT:    xxlxor 0, 0, 0
-; GLOBAL-NEXT:    fcmpu 0, 1, 0
-; GLOBAL-NEXT:    beq 0, .LBB16_2
-; GLOBAL-NEXT:  # %bb.1:
 ; GLOBAL-NEXT:    xsrsqrtesp 0, 1
 ; GLOBAL-NEXT:    vspltisw 2, -3
 ; GLOBAL-NEXT:    addis 3, 2, .LCPI16_0@toc@ha
-; GLOBAL-NEXT:    xvcvsxwdp 2, 34
-; GLOBAL-NEXT:    xsmulsp 1, 1, 0
-; GLOBAL-NEXT:    xsmaddasp 2, 1, 0
+; GLOBAL-NEXT:    xvcvsxwdp 3, 34
+; GLOBAL-NEXT:    xsmulsp 2, 1, 0
+; GLOBAL-NEXT:    xsmaddasp 3, 2, 0
 ; GLOBAL-NEXT:    lfs 0, .LCPI16_0@toc@l(3)
-; GLOBAL-NEXT:    xsmulsp 0, 1, 0
-; GLOBAL-NEXT:    xsmulsp 0, 0, 2
-; GLOBAL-NEXT:  .LBB16_2:
-; GLOBAL-NEXT:    fmr 1, 0
+; GLOBAL-NEXT:    xsmulsp 0, 2, 0
+; GLOBAL-NEXT:    xxlxor 2, 2, 2
+; GLOBAL-NEXT:    xsmulsp 0, 0, 3
+; GLOBAL-NEXT:    fsel 2, 1, 2, 0
+; GLOBAL-NEXT:    xsnegdp 1, 1
+; GLOBAL-NEXT:    fsel 1, 1, 2, 0
 ; GLOBAL-NEXT:    blr
   %rt = call contract reassoc ninf afn float @llvm.sqrt.f32(float %x)
   ret float %rt

virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
const DenormalMode &Mode) const;
const DenormalMode &Mode,
SDNodeFlags Flags) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SDNodeFlags Flags) const;
SDNodeFlags Flags = {}) const;

Comment on lines 7467 to 7468
// Propagate fast-math flags from fcmp.
Test->setFlags(Flags);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should avoid the node mutation; isn't this available as an argument to getSetCC? I thought setFlags was going away because it breaks on CSE.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, the only available signature of getSetCC doesn't support SDNodeFlags; I may add support later.

@paperchalice
Copy link
Contributor Author

Oh I see the SelectionDAG::FlagInserter.

@paperchalice paperchalice deleted the fsqrt-test-fmf branch October 14, 2025 07:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants