-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) #141480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-aarch64 Author: JP Hafer (jph-13) ChangesThis commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924 This update targets Full diff: https://github.com/llvm/llvm-project/pull/141480.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2800145cc603..bb094d9772c47 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
ISD::FP_TO_UINT_SAT, ISD::FADD});
+ // Try to fmul -> scvtf for powers of 2
+ setTargetDAGCombine(ISD::FMUL);
+
// Try and combine setcc with csel
setTargetDAGCombine(ISD::SETCC);
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
return FixConv;
}
+/// Try to extract a log2 exponent from a uniform constant FP splat.
+/// Returns -1 if the value is not a power-of-two float.
+static int getUniformFPSplatLog2(const BuildVectorSDNode *BV, unsigned MaxExponent) {
+ SDValue FirstElt = BV->getOperand(0);
+ if (!isa<ConstantFPSDNode>(FirstElt))
+ return -1;
+
+ const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
+ const APFloat &FirstVal = FirstConst->getValueAPF();
+ const fltSemantics &Sem = FirstVal.getSemantics();
+
+ // Check all elements are the same
+ for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
+ SDValue Elt = BV->getOperand(i);
+ if (!isa<ConstantFPSDNode>(Elt))
+ return -1;
+ const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
+ if (!Val.bitwiseIsEqual(FirstVal))
+ return -1;
+ }
+
+ // Reject zero, NaN, or negative values
+ if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
+ return -1;
+
+ // Get raw bits
+ APInt Bits = FirstVal.bitcastToAPInt();
+
+ int ExponentBias = 0;
+ unsigned ExponentBits = 0;
+ unsigned MantissaBits = 0;
+
+ if (&Sem == &APFloat::IEEEsingle()) {
+ ExponentBias = 127;
+ ExponentBits = 8;
+ MantissaBits = 23;
+ } else if (&Sem == &APFloat::IEEEdouble()) {
+ ExponentBias = 1023;
+ ExponentBits = 11;
+ MantissaBits = 52;
+ } else {
+ // Unsupported type
+ return -1;
+ }
+
+ // Mask out mantissa and check it's zero (i.e., power of two)
+ APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
+ if ((Bits & MantissaMask) != 0)
+ return -1;
+
+ // Extract exponent
+ unsigned ExponentShift = MantissaBits;
+ APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(),
+ ExponentShift,
+ ExponentShift + ExponentBits);
+ int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
+ int Log2 = ExponentBias - Exponent;
+
+ if (static_cast<unsigned>(Log2) > MaxExponent)
+ return -1;
+
+ return Log2;
+}
+
+/// Fold a floating-point multiply by power of two into fixed-point to
+/// floating-point conversion.
+static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+
+ if (!Subtarget->hasNEON())
+ return SDValue();
+
+ // N is the FMUL node.
+ if (N->getOpcode() != ISD::FMUL)
+ return SDValue();
+
+ // SINT_TO_FP or UINT_TO_FP
+ SDValue Op = N->getOperand(0);
+ unsigned Opc = Op->getOpcode();
+ if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
+ !Op.getOperand(0).getValueType().isSimple() ||
+ (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
+ return SDValue();
+
+ SDValue ConstVec = N->getOperand(1);
+ if (!isa<BuildVectorSDNode>(ConstVec))
+ return SDValue();
+
+ MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+ int32_t IntBits = IntTy.getSizeInBits();
+ if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+ return SDValue();
+
+ MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+ int32_t FloatBits = FloatTy.getSizeInBits();
+ if (FloatBits != 32 && FloatBits != 64)
+ return SDValue();
+
+ if (IntBits > FloatBits)
+ return SDValue();
+
+ BitVector UndefElements;
+ BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
+ int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
+
+ // Handle cases where it's not a power of two, or is 2^0.
+ if (IntrinsicC == -1 || IntrinsicC == 0)
+ return SDValue();
+
+ // Check if IntrinsicC is within the valid range [1, FloatBits].
+ // The 's' value must be in [1, FloatBits].
+ if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
+ return SDValue();
+
+ MVT ResTy;
+ unsigned NumLanes = Op.getValueType().getVectorNumElements();
+ switch (NumLanes) {
+ default:
+ return SDValue();
+ case 2:
+ ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+ break;
+ case 4:
+ ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
+ break;
+ }
+
+ if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue ConvInput = Op.getOperand(0);
+ bool IsSigned = Opc == ISD::SINT_TO_FP;
+
+ if (IntBits < FloatBits)
+ ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+ ResTy, ConvInput);
+
+ unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+ : Intrinsic::aarch64_neon_vcvtfxu2fp;
+
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+ DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+ DAG.getConstant(IntrinsicC, DL, MVT::i32));
+}
+
static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64TargetLowering &TLI) {
EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return performFpToIntCombine(N, DAG, DCI, Subtarget);
+ case ISD::FMUL:
+ return performFMulCombine(N, DAG, DCI, Subtarget);
case ISD::OR:
return performORCombine(N, DCI, Subtarget, *this);
case ISD::AND:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..befddb165fcce
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,47 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
+
+; Test case 1: Scalar fdiv by 16.0
+define float @tests(i32 %in) {
+; CHECK-LABEL: tests:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf s0, w0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp i32 %in to float
+ %div.i = fdiv float %vcvt.i, 16.0
+ ret float %div.i
+}
+
+; Test case 2: Scalar fmul by (2^-4)
+define float @testsmul(i32 %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testsmul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf s0, w0, #4
+; CHECK-NEXT: ret
+ %vcvt.i = sitofp i32 %in to float
+ %div.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
+ ret float %div.i
+}
+
+; Test case 3: Vector fdiv by 16.0
+define <2 x float> @testv(<2 x i32> %in) {
+; CHECK-LABEL: testv:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf.2s v0, v0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+ %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+ ret <2 x float> %div.i
+}
+
+; Test case 4: Vector fmul by 2^-4
+define <2 x float> @testvmul(<2 x i32> %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testvmul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf.2s v0, v0, #4
+; CHECK-NEXT: ret
+ %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+ %div.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
+ ret <2 x float> %div.i
+}
\ No newline at end of file
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
It looks like the existing scalar code uses tablegen patterns, via SelectCVTFixedPosRecipOperand. Would it be possible to do the same for vector operations? It would need to detect the splat constant. |
|
I will take a look at the tablegen impls. Maybe that will help me understand why I am having issues with half too. |
|
I just resolved all the original flags since the new implementation is very different. I did try to get f16 working but I became very confused. As of now it doesn't appear to have a match in TD. I started creating one but I am not sure if I shold replace all the round tripping or not. So I figured I would see if we could get this in, then maybe try another pass at half later. |
fbb93fd to
67a4484
Compare
|
Please ignore for now, not sure what I broke when I squashed. Sorry. |
|
This is incomplete, but I could really use some help. I have never touched tablegen before this and feel I am making it way too difficult. I am having specific issues with the smaller registers and sizes (are my matchers too restrictive?). I left in commented out vNi16 code for I can't get any of it to not conflict. The v1i32 and v1i64 are also escaping me. As for the f16 implementations, it seems I need to handle a sext, but I figure the rest should be working first. I could use any comments or guidance folks have. Thanks. |
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this got quite far. I've left a few comments inline, tablegen can be finicky at times.
llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
Outdated
Show resolved
Hide resolved
|
Looks good, thanks for your persistence. Can you add a couple of test cases for examples where we can't do the transform because the power of two is out of range? |
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed the update. Nice job getting this working, it looks good. I just have a couple of nitpicks, it looks good to me otherwise.
e7d1147 to
1cc9c56
Compare
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - I think the handling of MOVIshift might need to be tightened up a bit, otherwise it could ignore important certain bits of the constant. Detecting the constant correctly is the most awkward part of this with all the different combos, but luckily it looks like many of them use a similar node type.
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding all the tests, that's fantastic. I have a suggestion to make the code slightly more structured in case something happens with weird node types, but it LGTM otherwise.
…tf(x, 2) This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: llvm#91924 This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, 2^N)`. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
|
Hi - sorry for the delay. My plan was to wait until after the release and then get this committed and see how it does. I'm still not 100% on whether it might accidentally handle a node incorrectly if it has a constant operand. I was running some extra tests though and ran into this case, where this particular constant gets lowered to an movi I think, but the operand is an i32 (due to type legalization) and the fp type is a fp16. It just needs to use I would also recommend changing the start of the function to something like the code below, just to make it super clear how we derive the immediate for each opcode and it is obvious which opcode leads to what. |
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924
This update targets
fmul(sitofp(x), C)whereCis a constant reciprocal of a power of two. For both scalar and vector inputs, if we havesitofp(X) * C(whereCis1/2^N), this can be optimized toscvtf(X, 2^N). This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.