Conversation

@jph-13 (Contributor) commented May 26, 2025

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924

This update targets fmul(sitofp(x), C) where C is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have sitofp(X) * C (where C is 1/2^N), this can be optimized to the fixed-point conversion scvtf(X, #N), i.e. an scvtf with N fractional bits. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
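As a concrete example (mirroring the scalar test added in this patch), multiplying the converted value by 0.0625 = 2^-4 becomes a single fixed-point convert with #fbits = 4, instead of a plain scvtf followed by materializing the constant and an fmul:

define float @mul_by_recip_pow2(i32 %in) {
  %conv = sitofp i32 %in to float
  %scaled = fmul float %conv, 6.250000e-02  ; 0.0625 = 2^-4
  ret float %scaled
}

; Expected AArch64 codegen with the fold:
;   scvtf s0, w0, #4
;   ret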

@llvmbot added the backend:AArch64, llvm:instcombine, and llvm:transforms labels May 26, 2025
@llvmbot (Member) commented May 26, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: JP Hafer (jph-13)

Changes

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924

This update targets fmul(sitofp(x), C) where C is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have sitofp(X) * C (where C is 1/2^N), this can be optimized to the fixed-point conversion scvtf(X, #N), i.e. an scvtf with N fractional bits. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.


Full diff: https://github.com/llvm/llvm-project/pull/141480.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+152)
  • (added) llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll (+47)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2800145cc603..bb094d9772c47 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
                        ISD::FP_TO_UINT_SAT, ISD::FADD});
 
+  // Try to combine fmul -> scvtf for reciprocal powers of 2
+  setTargetDAGCombine(ISD::FMUL);
+
   // Try and combine setcc with csel
   setTargetDAGCombine(ISD::SETCC);
 
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
   return FixConv;
 }
 
+/// Try to extract a log2 exponent from a uniform constant FP splat.
+/// Returns -1 if the value is not a power-of-two float.
+static int getUniformFPSplatLog2(const BuildVectorSDNode *BV, unsigned MaxExponent) {
+  SDValue FirstElt = BV->getOperand(0);
+  if (!isa<ConstantFPSDNode>(FirstElt))
+    return -1;
+
+  const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
+  const APFloat &FirstVal = FirstConst->getValueAPF();
+  const fltSemantics &Sem = FirstVal.getSemantics();
+
+  // Check all elements are the same
+  for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
+    SDValue Elt = BV->getOperand(i);
+    if (!isa<ConstantFPSDNode>(Elt))
+      return -1;
+    const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
+    if (!Val.bitwiseIsEqual(FirstVal))
+      return -1;
+  }
+
+  // Reject zero, NaN, or negative values
+  if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
+    return -1;
+
+  // Get raw bits
+  APInt Bits = FirstVal.bitcastToAPInt();
+
+  int ExponentBias = 0;
+  unsigned ExponentBits = 0;
+  unsigned MantissaBits = 0;
+
+  if (&Sem == &APFloat::IEEEsingle()) {
+    ExponentBias = 127;
+    ExponentBits = 8;
+    MantissaBits = 23;
+  } else if (&Sem == &APFloat::IEEEdouble()) {
+    ExponentBias = 1023;
+    ExponentBits = 11;
+    MantissaBits = 52;
+  } else {
+    // Unsupported type
+    return -1;
+  }
+
+  // Mask out mantissa and check it's zero (i.e., power of two)
+  APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
+  if ((Bits & MantissaMask) != 0)
+    return -1;
+
+  // Extract exponent
+  unsigned ExponentShift = MantissaBits;
+  APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(),
+                                         ExponentShift,
+                                         ExponentShift + ExponentBits);
+  int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
+  int Log2 = ExponentBias - Exponent;
+
+  if (static_cast<unsigned>(Log2) > MaxExponent)
+    return -1;
+
+  return Log2;
+}
+
+/// Fold a floating-point multiply by power of two into fixed-point to
+/// floating-point conversion.
+static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const AArch64Subtarget *Subtarget) {
+
+  if (!Subtarget->hasNEON())
+    return SDValue();
+
+  // N is the FMUL node.
+  if (N->getOpcode() != ISD::FMUL)
+    return SDValue();
+
+  // SINT_TO_FP or UINT_TO_FP
+  SDValue Op = N->getOperand(0);
+  unsigned Opc = Op->getOpcode();
+  if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
+      !Op.getOperand(0).getValueType().isSimple() ||
+      (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
+    return SDValue();
+
+  SDValue ConstVec = N->getOperand(1);
+  if (!isa<BuildVectorSDNode>(ConstVec))
+    return SDValue();
+
+  MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+  int32_t IntBits = IntTy.getSizeInBits();
+  if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+    return SDValue();
+
+  MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+  int32_t FloatBits = FloatTy.getSizeInBits();
+  if (FloatBits != 32 && FloatBits != 64)
+    return SDValue();
+
+  if (IntBits > FloatBits)
+    return SDValue();
+
+  BitVector UndefElements;
+  BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
+  int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
+
+  // Handle cases where it's not a power of two, or is 2^0.
+  if (IntrinsicC == -1 || IntrinsicC == 0)
+    return SDValue();
+
+  // The scale (#fbits) must be in [1, FloatBits] for the fixed-point
+  // convert instructions; reject anything outside that range.
+  if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
+    return SDValue();
+
+  MVT ResTy;
+  unsigned NumLanes = Op.getValueType().getVectorNumElements();
+  switch (NumLanes) {
+  default:
+    return SDValue();
+  case 2:
+    ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+    break;
+  case 4:
+    ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
+    break;
+  }
+
+  if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ConvInput = Op.getOperand(0);
+  bool IsSigned = Opc == ISD::SINT_TO_FP;
+
+  if (IntBits < FloatBits)
+    ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+                            ResTy, ConvInput);
+
+  unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+                                      : Intrinsic::aarch64_neon_vcvtfxu2fp;
+
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+                     DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+                     DAG.getConstant(IntrinsicC, DL, MVT::i32));
+}
+
 static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64TargetLowering &TLI) {
   EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     return performFpToIntCombine(N, DAG, DCI, Subtarget);
+  case ISD::FMUL:
+    return performFMulCombine(N, DAG, DCI, Subtarget);
   case ISD::OR:
     return performORCombine(N, DCI, Subtarget, *this);
   case ISD::AND:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..befddb165fcce
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,47 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
+
+; Test case 1: Scalar fdiv by 16.0
+define float @tests(i32 %in) {
+; CHECK-LABEL: tests:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf   s0, w0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fdiv float %vcvt.i, 16.0
+  ret float %div.i
+}
+
+; Test case 2: Scalar fmul by (2^-4)
+define float @testsmul(i32 %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testsmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf   s0, w0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
+  ret float %div.i
+}
+
+; Test case 3: Vector fdiv by 16.0
+define <2 x float> @testv(<2 x i32> %in) {
+; CHECK-LABEL: testv:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf.2s        v0, v0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+  ret <2 x float> %div.i
+}
+
+; Test case 4: Vector fmul by 2^-4 
+define <2 x float> @testvmul(<2 x i32> %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testvmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf.2s        v0, v0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
+  ret <2 x float> %div.i
+}
\ No newline at end of file

@github-actions (bot) commented May 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@davemgreen (Collaborator) commented

It looks like the existing scalar code uses tablegen patterns, via SelectCVTFixedPosRecipOperand. Would it be possible to do the same for vector operations? It would need to detect the splat constant.
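(For orientation: the scalar lowering pairs ComplexPattern operands built on SelectCVTFixedPosRecipOperand with the fixed-point SCVTF instructions, so a vector equivalent would need a splat-aware operand. A rough TableGen sketch is below; fixedpoint_recip_v2f32, SelectCVTFixedPosRecipOperandVec and SCVTFv2i32_shift are illustrative placeholders, not the actual definitions in AArch64InstrInfo.td.)

// Hypothetical operand that matches a constant splat of 1/2^N and yields N.
def fixedpoint_recip_v2f32 : Operand<i32>,
    ComplexPattern<v2f32, 1, "SelectCVTFixedPosRecipOperandVec", [build_vector]>;

// fmul(sitofp(x), splat(1/2^N)) --> vector scvtf with #N fractional bits.
def : Pat<(v2f32 (fmul (sint_to_fp (v2i32 V64:$Rn)),
                       fixedpoint_recip_v2f32:$scale)),
          (SCVTFv2i32_shift V64:$Rn, fixedpoint_recip_v2f32:$scale)>;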

@jph-13 (Contributor, Author) commented May 27, 2025

I will take a look at the tablegen impls. Maybe that will help me understand why I am having issues with half too.

@jph-13 (Contributor, Author) commented May 30, 2025

I just resolved all the original flags since the new implementation is very different. I did try to get f16 working but I became very confused. As of now it doesn't appear to have a match in TD. I started creating one but I am not sure if I should replace all the round-tripping or not. So I figured I would see if we could get this in, then maybe take another pass at half later.

@jph-13 force-pushed the issue_94909 branch 2 times, most recently from fbb93fd to 67a4484 on May 30, 2025 18:54
@jph-13 (Contributor, Author) commented May 30, 2025

Please ignore for now, not sure what I broke when I squashed. Sorry.

@jph-13 (Contributor, Author) commented Jun 10, 2025

This is incomplete, but I could really use some help. I have never touched tablegen before this and feel I am making it way too difficult.

I am having specific issues with the smaller registers and sizes (are my matchers too restrictive?). I left the vNi16 code commented out because I can't get any of it to stop conflicting. The v1i32 and v1i64 cases are also escaping me. As for the f16 implementations, it seems I need to handle a sext, but I figure the rest should be working first.

I could use any comments or guidance folks have.

Thanks.

@davemgreen (Collaborator) left a comment

It looks like this got quite far. I've left a few comments inline; tablegen can be finicky at times.

@jph-13 marked this pull request as draft June 16, 2025 13:52
@jph-13 marked this pull request as ready for review June 18, 2025 15:12
@stephentyrone (Contributor) commented

Looks good, thanks for your persistence. Can you add a couple of test cases for examples where we can't do the transform because the power of two is out of range?
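One possible shape for such a negative test (a sketch only; CHECK lines omitted): the divisor below is 2^33, so the fold would need #fbits = 33, which exceeds the 32-bit limit of an f32 fixed-point scvtf, and the conversion plus divide must be left alone.

define float @tests_pow2_out_of_range(i32 %in) {
entry:
  %conv = sitofp i32 %in to float
  %div = fdiv float %conv, 0x4200000000000000 ; 2^33
  ret float %div
}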

@davemgreen (Collaborator) left a comment

Sorry I missed the update. Nice job getting this working. I just have a couple of nitpicks; it looks good to me otherwise.

@jph-13 force-pushed the issue_94909 branch 2 times, most recently from e7d1147 to 1cc9c56 on June 24, 2025 15:26
@davemgreen (Collaborator) left a comment

Thanks - I think the handling of MOVIshift might need to be tightened up a bit; otherwise it could ignore certain important bits of the constant. Detecting the constant correctly is the most awkward part of this with all the different combos, but luckily it looks like many of them use a similar node type.

@davemgreen (Collaborator) left a comment

Thanks for adding all the tests, that's fantastic. I have a suggestion to make the code slightly more structured in case something happens with weird node types, but it LGTM otherwise.

…tf(x, 2)

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability.
See: llvm#91924

This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to the fixed-point conversion `scvtf(X, #N)`, i.e. an scvtf with N fractional bits. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
@jph-13 requested a review from davemgreen July 14, 2025 14:53
@davemgreen (Collaborator) commented

Hi - sorry for the delay. My plan was to wait until after the release and then get this committed and see how it does. I'm still not 100% sure whether it might accidentally handle a node incorrectly when it has a constant operand.

I was running some extra tests though and ran into this case, where this particular constant gets lowered to a movi, I think, but the operand is an i32 (due to type legalization) and the fp type is an fp16. It just needs to use FVal = APFloat(APFloat::IEEEhalf(), Imm.trunc(16)); for fp16 values.

define <4 x half> @test_v4f16_div_const_0xH1c04(<4 x i16> %in) {
; CHECK-LABEL: test_v4f16_div_const_0xH1c04:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    mov w8, #7172 // =0x1c04
; CHECK-NEXT:    scvtf.4h v0, v0
; CHECK-NEXT:    dup.4h v1, w8
; CHECK-NEXT:    fmul.4h v0, v0, v1
; CHECK-NEXT:    ret
entry:
  %vcvt.i = sitofp <4 x i16> %in to <4 x half>
  %div.i = fmul <4 x half> %vcvt.i, <half 0xH1c04, half 0xH1c04, half 0xH1c04, half 0xH1c04>
  ret <4 x half> %div.i
}

I would also recommend changing the start of the function to something like the code below, just to make it super clear how we derive the immediate for each opcode, so it is obvious which opcode leads to what.

static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
                                                         SDValue N,
                                                         SDValue &FixedPos,
                                                         unsigned FloatWidth,
                                                         bool IsReciprocal) {
  SDValue ImmediateNode = N;
  if (N.getOpcode() == ISD::BITCAST || N.getOpcode() == AArch64ISD::NVCAST) {
    ImmediateNode = N.getOperand(0);
    // This could have been a bitcast to a scalar
    if (!ImmediateNode.getValueType().isVector())
      return false;
  }

  APInt Imm;
  if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
    // For BUILD_VECTOR, we must explicitly check if it's a constant splat.
    BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
    APInt SplatUndef;
    unsigned SplatBitSize;
    bool HasAnyUndefs;
    if (!BVN->isConstantSplat(Imm, SplatUndef, SplatBitSize, HasAnyUndefs) ||
        SplatBitSize != N.getValueType().getScalarSizeInBits())
      return false;
  } else if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
    EVT NodeVT = N.getValueType();
    Imm = APInt(NodeVT.getScalarSizeInBits(),
                ImmediateNode.getConstantOperandVal(0)
                    << ImmediateNode.getConstantOperandVal(1));
  } else if (ImmediateNode.getOpcode() == AArch64ISD::FMOV) {
    uint8_t EncodedU8 = ImmediateNode.getConstantOperandVal(0);
    uint64_t DecodedBits = AArch64_AM::decodeAdvSIMDModImmType11(EncodedU8);

    unsigned BitWidth = N.getValueType().getVectorElementType().getSizeInBits();
    uint64_t Mask = (BitWidth == 64) ? ~0ULL : ((1ULL << BitWidth) - 1);
    uint64_t MaskedBits = DecodedBits & Mask;

    Imm = APInt(BitWidth, MaskedBits);
  } else if (ImmediateNode.getOpcode() == AArch64ISD::DUP ||
             ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR) {
    // DUP / SPLAT_VECTOR carry the splatted scalar as operand 0.
    auto *CI = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0));
    if (!CI)
      return false;
    Imm = CI->getAPIntValue();
  } else {
    return false;
  }
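(For completeness, a rough sketch of how the rest of the helper might continue, following the existing scalar checkCVTFixedPointOperandWithFBits logic and using the IEEEhalf truncation mentioned above for fp16. This is untested and only illustrative.)

  // Reinterpret the splat bits with the correct FP semantics. For fp16 the
  // immediate can arrive as an i32 after type legalization, hence the trunc.
  APFloat FVal(0.0);
  if (FloatWidth == 16)
    FVal = APFloat(APFloat::IEEEhalf(), Imm.zextOrTrunc(16));
  else if (FloatWidth == 32)
    FVal = APFloat(APFloat::IEEEsingle(), Imm.zextOrTrunc(32));
  else
    FVal = APFloat(APFloat::IEEEdouble(), Imm.zextOrTrunc(64));

  // For fmul by 1/2^N, invert first so the power of two can be read directly.
  if (IsReciprocal && !FVal.getExactInverse(&FVal))
    return false;

  bool IsExact;
  APSInt IntVal(65, /*isUnsigned=*/true);
  FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
  if (!IsExact || !IntVal.isPowerOf2())
    return false;

  unsigned FBits = IntVal.logBase2();
  if (FBits == 0 || FBits > FloatWidth)
    return false;

  FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
  return true;
}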
