Skip to content

Commit 8c4f744

Browse files
committed
[AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2)
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924 This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, 2^N)`. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
1 parent d57b867 commit 8c4f744

File tree

3 files changed

+676
-0
lines changed

3 files changed

+676
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
487487
bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
488488
unsigned Width);
489489

490+
template <unsigned FloatWidth>
491+
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
492+
return SelectCVTFixedPosRecipOperandVec(N, FixedPos, FloatWidth);
493+
}
494+
495+
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
496+
unsigned Width);
497+
490498
bool SelectCMP_SWAP(SDNode *N);
491499

492500
bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,129 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
39523960
return true;
39533961
}
39543962

3963+
static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
3964+
SDValue N,
3965+
SDValue &FixedPos,
3966+
unsigned FloatWidth,
3967+
bool IsReciprocal) {
3968+
3969+
if (N->getNumOperands() < 1)
3970+
return false;
3971+
3972+
SDValue ImmediateNode = N.getOperand(0);
3973+
if (N.getOpcode() == ISD::BITCAST || N.getOpcode() == AArch64ISD::NVCAST) {
3974+
// This could have been a bitcast to a scalar
3975+
if (!ImmediateNode.getValueType().isVector())
3976+
return false;
3977+
}
3978+
3979+
if (!(ImmediateNode.getOpcode() == AArch64ISD::DUP ||
3980+
ImmediateNode.getOpcode() == AArch64ISD::MOVIshift ||
3981+
ImmediateNode.getOpcode() == ISD::BUILD_VECTOR ||
3982+
ImmediateNode.getOpcode() == ISD::Constant ||
3983+
ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR)) {
3984+
return false;
3985+
}
3986+
3987+
if (ImmediateNode.getOpcode() != ISD::Constant) {
3988+
auto *C = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0));
3989+
if (!C)
3990+
return false;
3991+
}
3992+
3993+
if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
3994+
// For BUILD_VECTOR, we must explicitly check if it's a constant splat.
3995+
BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
3996+
APInt SplatValue;
3997+
APInt SplatUndef;
3998+
unsigned SplatBitSize;
3999+
bool HasAnyUndefs;
4000+
if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
4001+
HasAnyUndefs)) {
4002+
return false;
4003+
}
4004+
}
4005+
4006+
APInt Imm;
4007+
bool IsIntConstant = false;
4008+
if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
4009+
EVT NodeVT = N.getValueType();
4010+
Imm = APInt(NodeVT.getScalarSizeInBits(),
4011+
ImmediateNode.getConstantOperandVal(0)
4012+
<< ImmediateNode.getConstantOperandVal(1));
4013+
IsIntConstant = true;
4014+
} else if (ImmediateNode.getOpcode() == ISD::Constant) {
4015+
auto *C = dyn_cast<ConstantSDNode>(ImmediateNode);
4016+
if (!C)
4017+
return false;
4018+
uint8_t EncodedU8 = static_cast<uint8_t>(C->getZExtValue());
4019+
uint64_t DecodedBits = AArch64_AM::decodeAdvSIMDModImmType11(EncodedU8);
4020+
4021+
unsigned BitWidth = N.getValueType().getVectorElementType().getSizeInBits();
4022+
uint64_t Mask = (BitWidth == 64) ? ~0ULL : ((1ULL << BitWidth) - 1);
4023+
uint64_t MaskedBits = DecodedBits & Mask;
4024+
4025+
Imm = APInt(BitWidth, MaskedBits);
4026+
IsIntConstant = true;
4027+
} else if (auto *CI = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0))) {
4028+
Imm = CI->getAPIntValue();
4029+
IsIntConstant = true;
4030+
}
4031+
4032+
APFloat FVal(0.0);
4033+
// --- Extract the actual constant value ---
4034+
if (IsIntConstant) {
4035+
// Scalar source is an integer constant; interpret its bits as
4036+
// floating-point.
4037+
EVT FloatEltVT = N.getValueType().getVectorElementType();
4038+
4039+
if (FloatEltVT == MVT::f32) {
4040+
FVal = APFloat(APFloat::IEEEsingle(), Imm);
4041+
} else if (FloatEltVT == MVT::f64) {
4042+
FVal = APFloat(APFloat::IEEEdouble(), Imm);
4043+
} else if (FloatEltVT == MVT::f16) {
4044+
FVal = APFloat(APFloat::IEEEhalf(), Imm);
4045+
} else {
4046+
// Unsupported floating-point element type.
4047+
return false;
4048+
}
4049+
} else {
4050+
// ScalarSourceNode is not a recognized constant type.
4051+
return false;
4052+
}
4053+
4054+
// Handle reciprocal case.
4055+
if (IsReciprocal) {
4056+
if (!FVal.getExactInverse(&FVal))
4057+
// Not an exact reciprocal, or reciprocal not a power of 2.
4058+
return false;
4059+
}
4060+
4061+
bool IsExact;
4062+
unsigned TargetIntBits =
4063+
N.getValueType().getVectorElementType().getSizeInBits();
4064+
APSInt IntVal(
4065+
TargetIntBits + 1,
4066+
true); // Use TargetIntBits + 1 for sufficient bits for conversion
4067+
4068+
FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
4069+
4070+
if (!IsExact || !IntVal.isPowerOf2())
4071+
return false;
4072+
4073+
unsigned FBits = IntVal.logBase2();
4074+
// FBits must be non-zero (implies actual scaling) and within the range
4075+
// supported by the instruction (typically 1 to 64 for AArch64 FCVTZS/FCVTZU).
4076+
// FloatWidth should ideally be the width of the *integer elements* in the
4077+
// vector (16, 32, 64).
4078+
if (FBits == 0 || FBits > FloatWidth)
4079+
return false;
4080+
4081+
// Set FixedPos to the extracted FBits as an i32 constant SDValue.
4082+
FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
4083+
return true;
4084+
}
4085+
39554086
bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
39564087
unsigned RegWidth) {
39574088
return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4096,12 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
39654096
true);
39664097
}
39674098

4099+
bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(
4100+
SDValue N, SDValue &FixedPos, unsigned FloatWidth) {
4101+
return checkCVTFixedPointOperandWithFBitsForVectors(CurDAG, N, FixedPos,
4102+
FloatWidth, true);
4103+
}
4104+
39684105
// Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
39694106
// of the string and obtains the integer values from them and combines these
39704107
// into a single value to be used in the MRS/MSR instruction.

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8473,6 +8473,58 @@ def : Pat<(v8f16 (sint_to_fp (v8i16 (AArch64vashr_exact v8i16:$Vn, i32:$shift)))
84738473
(SCVTFv8i16_shift $Vn, vecshiftR16:$shift)>;
84748474
}
84758475

8476+
// Select fmul(sitofp(x), C) where C is a constant reciprocal of a power of two.
8477+
// For both scalar and vector inputs, if we have sitofp(X) * C (where C is
8478+
// 1/2^N), this can be optimized to scvtf(X, 2^N).
8479+
class fixedpoint_recip_vec_i16<ValueType FloatVT>
8480+
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<16>", []>;
8481+
class fixedpoint_recip_vec_i32<ValueType FloatVT>
8482+
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<32>", []>;
8483+
class fixedpoint_recip_vec_i64<ValueType FloatVT>
8484+
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<64>", []>;
8485+
def fixedpoint_recip_vec_xform : SDNodeXForm<timm, [{
8486+
// Suppress the unused variable warning by explicitly using N.
8487+
// The actual value needed for the pattern is already in V.
8488+
(void)N;
8489+
return V;
8490+
}]>;
8491+
8492+
def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
8493+
def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
8494+
def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;
8495+
8496+
def fixedpoint_recip_v4f16_v4i16 : fixedpoint_recip_vec_i16<v4f16>;
8497+
def fixedpoint_recip_v8f16_v8i16 : fixedpoint_recip_vec_i16<v8f16>;
8498+
8499+
let Predicates = [HasNEON] in {
8500+
def : Pat<(v2f32(fmul(sint_to_fp(v2i32 V64:$Rn)),
8501+
fixedpoint_recip_v2f32_v2i32:$scale)),
8502+
(v2f32(SCVTFv2i32_shift(v2i32 V64:$Rn),
8503+
(fixedpoint_recip_vec_xform fixedpoint_recip_v2f32_v2i32:$scale)))>;
8504+
8505+
def : Pat<(v4f32(fmul(sint_to_fp(v4i32 FPR128:$Rn)),
8506+
fixedpoint_recip_v4f32_v4i32:$scale)),
8507+
(v4f32(SCVTFv4i32_shift(v4i32 FPR128:$Rn),
8508+
(fixedpoint_recip_vec_xform fixedpoint_recip_v4f32_v4i32:$scale)))>;
8509+
8510+
def : Pat<(v2f64(fmul(sint_to_fp(v2i64 FPR128:$Rn)),
8511+
fixedpoint_recip_v2f64_v2i64:$scale)),
8512+
(v2f64(SCVTFv2i64_shift(v2i64 FPR128:$Rn),
8513+
(fixedpoint_recip_vec_xform fixedpoint_recip_v2f64_v2i64:$scale)))>;
8514+
}
8515+
8516+
let Predicates = [HasNEON, HasFullFP16] in {
8517+
def : Pat<(v4f16(fmul(sint_to_fp(v4i16 V64:$Rn)),
8518+
fixedpoint_recip_v4f16_v4i16:$scale)),
8519+
(v4f16(SCVTFv4i16_shift(v4i16 V64:$Rn),
8520+
(fixedpoint_recip_vec_xform fixedpoint_recip_v4f16_v4i16:$scale)))>;
8521+
8522+
def : Pat<(v8f16(fmul(sint_to_fp(v8i16 FPR128:$Rn)),
8523+
fixedpoint_recip_v8f16_v8i16:$scale)),
8524+
(v8f16(SCVTFv8i16_shift(v8i16 FPR128:$Rn),
8525+
(fixedpoint_recip_vec_xform fixedpoint_recip_v8f16_v8i16:$scale)))>;
8526+
}
8527+
84768528
// X << 1 ==> X + X
84778529
class SHLToADDPat<ValueType ty, RegisterClass regtype>
84788530
: Pat<(ty (AArch64vshl (ty regtype:$Rn), (i32 1))),

0 commit comments

Comments
 (0)