Skip to content

Commit ef6dbc6

Browse files
committed
[AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2)
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924. This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, #N)`, i.e. a fixed-point conversion with `N` fractional bits. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
1 parent 1c3320c commit ef6dbc6

File tree

3 files changed

+358
-0
lines changed

3 files changed

+358
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
487487
bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
488488
unsigned Width);
489489

490+
template <unsigned RegWidth>
491+
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
492+
return SelectCVTFixedPosRecipOperandVec(N, FixedPos, RegWidth);
493+
}
494+
495+
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
496+
unsigned Width);
497+
490498
bool SelectCMP_SWAP(SDNode *N);
491499

492500
bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,156 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
39523960
return true;
39533961
}
39543962

3963+
static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
3964+
SDValue N,
3965+
SDValue &FixedPos,
3966+
unsigned RegWidth,
3967+
bool isReciprocal) {
3968+
3969+
// Fast Path
3970+
if (N.getOpcode() == ISD::BUILD_VECTOR) {
3971+
// Match build_vector <float C, float C, ...>
3972+
unsigned NumElts = N.getNumOperands();
3973+
ConstantFPSDNode *First = dyn_cast<ConstantFPSDNode>(N.getOperand(0));
3974+
if (!First)
3975+
return false;
3976+
3977+
APFloat FVal = First->getValueAPF();
3978+
for (unsigned i = 1; i < NumElts; ++i) {
3979+
ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(N.getOperand(i));
3980+
if (!CFP || !CFP->isExactlyValue(FVal))
3981+
return false;
3982+
}
3983+
3984+
if (N.getValueType().getVectorElementType() == MVT::f16) {
3985+
bool ignored;
3986+
FVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
3987+
&ignored);
3988+
}
3989+
3990+
if (isReciprocal) {
3991+
if (!FVal.getExactInverse(&FVal))
3992+
return false;
3993+
}
3994+
3995+
bool IsExact;
3996+
APSInt IntVal(RegWidth + 1, true);
3997+
FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
3998+
3999+
if (!IsExact || !IntVal.isPowerOf2())
4000+
return false;
4001+
4002+
unsigned FBits = IntVal.logBase2();
4003+
if (FBits == 0 || FBits > RegWidth)
4004+
return false;
4005+
4006+
FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
4007+
return true;
4008+
}
4009+
4010+
// N must be ISD::BITCAST and convert a vector integer type to a vector float
4011+
// type.
4012+
if (N.getOpcode() != ISD::BITCAST || !N.getValueType().isVector() ||
4013+
!N.getValueType().isFloatingPoint()) {
4014+
return false;
4015+
}
4016+
SDValue VectorIntNode = N.getOperand(
4017+
0); // This is the v2i32 node (t16 in your DAG), likely AArch64ISD::DUP
4018+
4019+
// The source of the bitcast must be a splat-forming operation from a
4020+
// constant.
4021+
SDValue ScalarSourceNode;
4022+
bool isSplatConfirmed = false;
4023+
4024+
if (VectorIntNode.getOpcode() == AArch64ISD::DUP) {
4025+
// AArch64ISD::DUP inherently means a splat of its scalar operand.
4026+
ScalarSourceNode = VectorIntNode.getOperand(0);
4027+
isSplatConfirmed = true;
4028+
} else if (VectorIntNode.getOpcode() == ISD::SPLAT_VECTOR) {
4029+
ScalarSourceNode = VectorIntNode.getOperand(0);
4030+
isSplatConfirmed = true;
4031+
} else if (VectorIntNode.getOpcode() == ISD::BUILD_VECTOR) {
4032+
// For ISD::BUILD_VECTOR, we must explicitly check if it's a constant splat.
4033+
BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(VectorIntNode.getNode());
4034+
APInt SplatValue;
4035+
APInt SplatUndef;
4036+
unsigned SplatBitSize;
4037+
bool HasAnyUndefs;
4038+
if (BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
4039+
HasAnyUndefs)) {
4040+
ScalarSourceNode = VectorIntNode.getOperand(0);
4041+
isSplatConfirmed = true;
4042+
return false; // BUILD_VECTOR was not a splat
4043+
}
4044+
} else {
4045+
// The node below the bitcast is not a recognized splat-forming node.
4046+
return false;
4047+
}
4048+
4049+
if (!isSplatConfirmed)
4050+
return false;
4051+
4052+
// ScalarSourceNode must be a constant (ISD::Constant or ISD::ConstantFP).
4053+
APFloat FVal(0.0);
4054+
if (auto *CFP = dyn_cast<ConstantFPSDNode>(ScalarSourceNode)) {
4055+
FVal = CFP->getValueAPF();
4056+
} else if (auto *CI = dyn_cast<ConstantSDNode>(ScalarSourceNode)) {
4057+
// If it's an integer constant, interpret its bits as a floating-point
4058+
// value. The target float element type is from
4059+
// N.getValueType().getVectorElementType()
4060+
EVT FloatEltVT = N.getValueType().getVectorElementType();
4061+
4062+
if (FloatEltVT == MVT::f32) {
4063+
FVal = APFloat(APFloat::IEEEsingle(), CI->getAPIntValue());
4064+
} else if (FloatEltVT == MVT::f64) {
4065+
FVal = APFloat(APFloat::IEEEdouble(), CI->getAPIntValue());
4066+
} else if (FloatEltVT == MVT::f16) {
4067+
FVal = APFloat(APFloat::IEEEhalf(), CI->getAPIntValue());
4068+
} else {
4069+
return false;
4070+
}
4071+
} else {
4072+
return false;
4073+
}
4074+
4075+
// 4. Perform fixed-point reciprocal check and power-of-2 validation on FVal.
4076+
// Normalize f16 to f32 if needed for consistent APFloat operations (if
4077+
// VecFloatVT was v2f16).
4078+
if (N.getValueType().getVectorElementType() == MVT::f16) {
4079+
bool ignored;
4080+
FVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
4081+
}
4082+
4083+
// Handle reciprocal case if applicable for this fixed-point conversion.
4084+
if (isReciprocal) {
4085+
if (!FVal.getExactInverse(&FVal))
4086+
return false;
4087+
}
4088+
4089+
bool IsExact;
4090+
// RegWidth is the width of the floating point element type (e.g., 32 for f32,
4091+
// 64 for f64).
4092+
APSInt IntVal(RegWidth + 1,
4093+
true); // Use RegWidth + 1 for sufficient bits for conversion
4094+
FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
4095+
4096+
if (!IsExact || !IntVal.isPowerOf2())
4097+
return false;
4098+
4099+
unsigned FBits = IntVal.logBase2();
4100+
// FBits must be non-zero and within the expected range for the instruction's
4101+
// scale field. The scale field is 6 bits, so FBits must be <= 63.
4102+
if (FBits == 0 ||
4103+
FBits > RegWidth) // FBits should fit within the float's precision
4104+
return false;
4105+
4106+
// 5. Set FixedPos to the extracted FBits as an i32 constant SDValue.
4107+
// This is the i32 immediate that the SCVTF instruction's 'scale' operand
4108+
// expects.
4109+
FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
4110+
return true;
4111+
}
4112+
39554113
bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
39564114
unsigned RegWidth) {
39574115
return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4123,13 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
39654123
true);
39664124
}
39674125

4126+
/// Vector counterpart of SelectCVTFixedPosRecipOperand: delegates to
/// checkCVTFixedPointOperandWithFBitsForVectors with the reciprocal flag set,
/// so \p N is matched as a splat of 1/2^N and \p FixedPos receives the
/// fractional-bit count on success.
bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(SDValue N,
                                                           SDValue &FixedPos,
                                                           unsigned RegWidth) {
  constexpr bool IsReciprocal = true;
  return checkCVTFixedPointOperandWithFBitsForVectors(CurDAG, N, FixedPos,
                                                      RegWidth, IsReciprocal);
}
4132+
39684133
// Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
39694134
// of the string and obtains the integer values from them and combines these
39704135
// into a single value to be used in the MRS/MSR instruction.

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,22 @@ class fixedpoint_recip_i64<ValueType FloatVT>
799799
let DecoderMethod = "DecodeFixedPointScaleImm64";
800800
}
801801

802+
// Vector reciprocal fixed-point scale operand, selected by the C++ hook
// SelectCVTFixedPosRecipOperandVec<32>. VecFloatVT is the floating-point
// vector type of the constant operand being matched.
class fixedpoint_recip_vec_i32<ValueType VecFloatVT>
    : Operand<VecFloatVT>,
      ComplexPattern<VecFloatVT, 1, "SelectCVTFixedPosRecipOperandVec<32>",
                     [build_vector]> {
  let DecoderMethod = "DecodeFixedPointScaleImm32";
  let EncoderMethod = "getFixedPointScaleOpValue";
}
809+
810+
// Vector reciprocal fixed-point scale operand, selected by the C++ hook
// SelectCVTFixedPosRecipOperandVec<64>. VecFloatVT is the floating-point
// vector type of the constant operand being matched.
//
// The 64-bit decoder is used so that this class agrees with the scalar
// fixedpoint_recip_i64 operand, which pairs a 64-bit selector width with
// DecodeFixedPointScaleImm64.
class fixedpoint_recip_vec_i64<ValueType VecFloatVT>
    : Operand<VecFloatVT>,
      ComplexPattern<VecFloatVT, 1, "SelectCVTFixedPosRecipOperandVec<64>",
                     [build_vector]> {
  let EncoderMethod = "getFixedPointScaleOpValue";
  let DecoderMethod = "DecodeFixedPointScaleImm64";
}
817+
802818
def fixedpoint_recip_f16_i32 : fixedpoint_recip_i32<f16>;
803819
def fixedpoint_recip_f32_i32 : fixedpoint_recip_i32<f32>;
804820
def fixedpoint_recip_f64_i32 : fixedpoint_recip_i32<f64>;
@@ -807,6 +823,16 @@ def fixedpoint_recip_f16_i64 : fixedpoint_recip_i64<f16>;
807823
def fixedpoint_recip_f32_i64 : fixedpoint_recip_i64<f32>;
808824
def fixedpoint_recip_f64_i64 : fixedpoint_recip_i64<f64>;
809825

826+
// Vector reciprocal fixed-point scale operands matched with a 32-bit width
// (instances of fixedpoint_recip_vec_i32).
// NOTE(review): in the visible part of this file only the v2f32 and v4f32
// operands are used by active patterns; the f16 variants are referenced
// solely from commented-out code -- confirm they are still wanted.
def fixedpoint_recip_v2f16_v2i32 : fixedpoint_recip_vec_i32<v2f16>;
827+
def fixedpoint_recip_v4f16_v4i32 : fixedpoint_recip_vec_i32<v4f16>;
828+
def fixedpoint_recip_v8f16_v8i32 : fixedpoint_recip_vec_i32<v8f16>;
829+
def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
830+
def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
831+
832+
// Vector reciprocal fixed-point scale operands matched with a 64-bit width
// (instances of fixedpoint_recip_vec_i64).
// NOTE(review): only the v2f64 operand is used by an active pattern in the
// visible part of this file -- confirm the other two are still wanted.
def fixedpoint_recip_v2f16_v2i64 : fixedpoint_recip_vec_i64<v2f16>;
833+
def fixedpoint_recip_v2f32_v2i64 : fixedpoint_recip_vec_i64<v2f32>;
834+
def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;
835+
810836
def vecshiftR8 : Operand<i32>, ImmLeaf<i32, [{
811837
return (((uint32_t)Imm) > 0) && (((uint32_t)Imm) < 9);
812838
}]> {
@@ -5407,6 +5433,102 @@ class BaseIntegerToFPUnscaled<bits<2> rmode, bits<3> opcode,
54075433
let Inst{4-0} = Rd;
54085434
}
54095435

5436+
// Defines one "_V" instruction record derived from BaseIntegerToFP and
// overrides a handful of encoding fields with the multiclass arguments:
//   q               -> Inst{30}
//   size            -> Inst{23-22}
//   0b001 (fixed)   -> Inst{18-16}
//   srcElemTypeBits -> Inst{11-10}
// No selection pattern is attached here (the pattern list is empty); `preds`
// becomes the record's Predicates list.
// NOTE(review): BaseIntegerToFP is defined outside this view -- confirm that
// the fields overridden here do not clash with the ones it already assigns.
multiclass IntegerToFPVector<bits<2> rmode, bits<3> opcode, string asm,
                             RegisterClass srcRegClass,
                             RegisterClass dstRegClass, Operand imm_op,
                             bits<1> q, bits<2> size,
                             bits<2> srcElemTypeBits,
                             list<Predicate> preds> {
  def _V : BaseIntegerToFP<rmode, opcode, srcRegClass, dstRegClass, imm_op, asm,
                           []> {
    let Predicates = preds;
    let Inst{30} = q;
    let Inst{23-22} = size;
    let Inst{18-16} = 0b001;
    let Inst{11-10} = srcElemTypeBits;
  }
}
5450+
5451+
// SCVTF (Signed Convert To Floating-Point) from Vector 32-bit Integer (vNi32)
5452+
// defm SCVTFv2f16_v2i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
5453+
// FPR64, FPR64,
5454+
// fixedpoint_recip_v2f16_v2i32,
5455+
// 0, 0b00, 0b10, [HasFullFP16]>;
5456+
5457+
// defm SCVTFv4f16_v4i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
5458+
// FPR128, FPR128,
5459+
// fixedpoint_recip_v4f16_v4i32,
5460+
// 1, 0b00, 0b10, [HasFullFP16]>;
5461+
5462+
// defm SCVTFv8f16_v8i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
5463+
// FPR128, FPR128,
5464+
// fixedpoint_recip_v8f16_v8i32,
5465+
// 1, 0b00, 0b10, [HasFullFP16]>;
5466+
5467+
// scvtf with a v2f32 reciprocal scale operand: FPR64 source and destination,
// q = 0, size = 0b01, srcElemTypeBits = 0b10, no extra predicates.
defm SCVTFv2f32_v2i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
                                          FPR64, FPR64,
                                          fixedpoint_recip_v2f32_v2i32,
                                          0, 0b01, 0b10, []>;

// scvtf with a v4f32 reciprocal scale operand: FPR128 source and destination,
// q = 1, size = 0b01, srcElemTypeBits = 0b10, no extra predicates.
defm SCVTFv4f32_v4i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
                                          FPR128, FPR128,
                                          fixedpoint_recip_v4f32_v4i32,
                                          1, 0b01, 0b10, []>;
5474+
5475+
// SCVTF (Signed Convert To Floating-Point) from Vector 64-bit Integer (vNi64)
5476+
// defm SCVTFv2f16_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
5477+
// FPR128, FPR128,
5478+
// fixedpoint_recip_v2f16_v2i64,
5479+
// 1, 0b00, 0b11, [HasFullFP16]>;
5480+
5481+
// defm SCVTFv2f32_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
5482+
// FPR128, FPR128,
5483+
// fixedpoint_recip_v2f32_v2i64,
5484+
// 1, 0b01, 0b11, []>;
5485+
5486+
// scvtf with a v2f64 reciprocal scale operand: FPR128 source and destination,
// q = 1, size = 0b10, srcElemTypeBits = 0b11, no extra predicates.
defm SCVTFv2f64_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
                                          FPR128, FPR128,
                                          fixedpoint_recip_v2f64_v2i64,
                                          1, 0b10, 0b11, []>;
5489+
5490+
// def : Pat<
5491+
// (fmul (sint_to_fp (v2i32 V64:$Rn)),
5492+
// fixedpoint_recip_v2f32_v2i32:$scale),
5493+
// (SCVTFv2f16_v2i32_V V64:$Rn, fixedpoint_recip_v2f32_v2i32:$scale)
5494+
// >;
5495+
5496+
// def : Pat<
5497+
// (fmul (sint_to_fp (v4i32 FPR128:$Rn)),
5498+
// fixedpoint_recip_v4f16_v4i32:$scale),
5499+
// (SCVTFv4f16_v4i32_V FPR128:$Rn, fixedpoint_recip_v4f16_v4i32:$scale)
5500+
// >;
5501+
5502+
// def : Pat<
5503+
// (fmul (sint_to_fp (v8i32 FPR128:$Rn)),
5504+
// fixedpoint_recip_v8f16_v8i32:$scale),
5505+
// (SCVTFv8f16_v8i32_V FPR128:$Rn, fixedpoint_recip_v8f16_v8i32:$scale)
5506+
// >;
5507+
5508+
// fmul(sint_to_fp(v2i32 x), splat(1/2^N)) -> SCVTFv2f32_v2i32_V x, #N
// NOTE(review): the source is constrained to V64 here while the instruction
// was instantiated with FPR64 -- confirm the register classes are meant to
// differ.
def : Pat<(fmul (sint_to_fp (v2i32 V64:$Rn)),
                fixedpoint_recip_v2f32_v2i32:$scale),
          (SCVTFv2f32_v2i32_V V64:$Rn, fixedpoint_recip_v2f32_v2i32:$scale)>;

// fmul(sint_to_fp(v4i32 x), splat(1/2^N)) -> SCVTFv4f32_v4i32_V x, #N
def : Pat<(fmul (sint_to_fp (v4i32 FPR128:$Rn)),
                fixedpoint_recip_v4f32_v4i32:$scale),
          (SCVTFv4f32_v4i32_V FPR128:$Rn,
                              fixedpoint_recip_v4f32_v4i32:$scale)>;
5515+
5516+
// def : Pat<
5517+
// (fmul (sint_to_fp (v2i64 FPR128:$Rn)),
5518+
// fixedpoint_recip_v2f16_v2i64:$scale),
5519+
// (SCVTFv2f16_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f16_v2i64:$scale)
5520+
// >;
5521+
5522+
// def : Pat<
5523+
// (fmul (sint_to_fp (v2i64 FPR128:$Rn)),
5524+
// fixedpoint_recip_v2f32_v2i64:$scale),
5525+
// (SCVTFv2f32_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f32_v2i64:$scale)
5526+
// >;
5527+
5528+
// fmul(sint_to_fp(v2i64 x), splat(1/2^N)) -> SCVTFv2f64_v2i64_V x, #N
def : Pat<(fmul (sint_to_fp (v2i64 FPR128:$Rn)),
                fixedpoint_recip_v2f64_v2i64:$scale),
          (SCVTFv2f64_v2i64_V FPR128:$Rn,
                              fixedpoint_recip_v2f64_v2i64:$scale)>;
5531+
54105532
multiclass IntegerToFP<bits<2> rmode, bits<3> opcode, string asm, SDPatternOperator node> {
54115533
// Unscaled
54125534
def UWHri: BaseIntegerToFPUnscaled<rmode, opcode, GPR32, FPR16, f16, asm, node> {

0 commit comments

Comments
 (0)