Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
unsigned Width);

template <unsigned FloatWidth>
bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
return SelectCVTFixedPosRecipOperandVec(N, FixedPos, FloatWidth);
}

bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
unsigned Width);

bool SelectCMP_SWAP(SDNode *N);

bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
Expand Down Expand Up @@ -3952,6 +3960,129 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
return true;
}

static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
SDValue N,
SDValue &FixedPos,
unsigned FloatWidth,
bool IsReciprocal) {

if (N->getNumOperands() < 1)
return false;

SDValue ImmediateNode = N.getOperand(0);
if (N.getOpcode() == ISD::BITCAST || N.getOpcode() == AArch64ISD::NVCAST) {
// This could have been a bitcast to a scalar
if (!ImmediateNode.getValueType().isVector())
return false;
}

if (!(ImmediateNode.getOpcode() == AArch64ISD::DUP ||
ImmediateNode.getOpcode() == AArch64ISD::MOVIshift ||
ImmediateNode.getOpcode() == ISD::BUILD_VECTOR ||
ImmediateNode.getOpcode() == ISD::Constant ||
ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR)) {
return false;
}

if (ImmediateNode.getOpcode() != ISD::Constant) {
auto *C = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0));
if (!C)
return false;
}

if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
// For BUILD_VECTOR, we must explicitly check if it's a constant splat.
BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
APInt SplatValue;
APInt SplatUndef;
unsigned SplatBitSize;
bool HasAnyUndefs;
if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
HasAnyUndefs)) {
return false;
}
}

APInt Imm;
bool IsIntConstant = false;
if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
EVT NodeVT = N.getValueType();
Imm = APInt(NodeVT.getScalarSizeInBits(),
ImmediateNode.getConstantOperandVal(0)
<< ImmediateNode.getConstantOperandVal(1));
IsIntConstant = true;
} else if (ImmediateNode.getOpcode() == ISD::Constant) {
auto *C = dyn_cast<ConstantSDNode>(ImmediateNode);
if (!C)
return false;
uint8_t EncodedU8 = static_cast<uint8_t>(C->getZExtValue());
uint64_t DecodedBits = AArch64_AM::decodeAdvSIMDModImmType11(EncodedU8);

unsigned BitWidth = N.getValueType().getVectorElementType().getSizeInBits();
uint64_t Mask = (BitWidth == 64) ? ~0ULL : ((1ULL << BitWidth) - 1);
uint64_t MaskedBits = DecodedBits & Mask;

Imm = APInt(BitWidth, MaskedBits);
IsIntConstant = true;
} else if (auto *CI = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0))) {
Imm = CI->getAPIntValue();
IsIntConstant = true;
}

APFloat FVal(0.0);
// --- Extract the actual constant value ---
if (IsIntConstant) {
// Scalar source is an integer constant; interpret its bits as
// floating-point.
EVT FloatEltVT = N.getValueType().getVectorElementType();

if (FloatEltVT == MVT::f32) {
FVal = APFloat(APFloat::IEEEsingle(), Imm);
} else if (FloatEltVT == MVT::f64) {
FVal = APFloat(APFloat::IEEEdouble(), Imm);
} else if (FloatEltVT == MVT::f16) {
FVal = APFloat(APFloat::IEEEhalf(), Imm);
} else {
// Unsupported floating-point element type.
return false;
}
} else {
// ScalarSourceNode is not a recognized constant type.
return false;
}

// Handle reciprocal case.
if (IsReciprocal) {
if (!FVal.getExactInverse(&FVal))
// Not an exact reciprocal, or reciprocal not a power of 2.
return false;
}

bool IsExact;
unsigned TargetIntBits =
N.getValueType().getVectorElementType().getSizeInBits();
APSInt IntVal(
TargetIntBits + 1,
true); // Use TargetIntBits + 1 for sufficient bits for conversion

FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);

if (!IsExact || !IntVal.isPowerOf2())
return false;

unsigned FBits = IntVal.logBase2();
// FBits must be non-zero (implies actual scaling) and within the range
// supported by the instruction (typically 1 to 64 for AArch64 FCVTZS/FCVTZU).
// FloatWidth should ideally be the width of the *integer elements* in the
// vector (16, 32, 64).
if (FBits == 0 || FBits > FloatWidth)
return false;

// Set FixedPos to the extracted FBits as an i32 constant SDValue.
FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
return true;
}

bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
unsigned RegWidth) {
return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
Expand All @@ -3965,6 +4096,12 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
true);
}

bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(
SDValue N, SDValue &FixedPos, unsigned FloatWidth) {
return checkCVTFixedPointOperandWithFBitsForVectors(CurDAG, N, FixedPos,
FloatWidth, true);
}

// Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
// of the string and obtains the integer values from them and combines these
// into a single value to be used in the MRS/MSR instruction.
Expand Down
52 changes: 52 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -8473,6 +8473,58 @@ def : Pat<(v8f16 (sint_to_fp (v8i16 (AArch64vashr_exact v8i16:$Vn, i32:$shift)))
(SCVTFv8i16_shift $Vn, vecshiftR16:$shift)>;
}

// Select fmul(sitofp(x), C) where C is a constant reciprocal of a power of two.
// For both scalar and vector inputs, if we have sitofp(X) * C (where C is
// 1/2^N), this can be optimized to scvtf(X, 2^N).
class fixedpoint_recip_vec_i16<ValueType FloatVT>
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<16>", []>;
class fixedpoint_recip_vec_i32<ValueType FloatVT>
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<32>", []>;
class fixedpoint_recip_vec_i64<ValueType FloatVT>
: ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<64>", []>;
def fixedpoint_recip_vec_xform : SDNodeXForm<timm, [{
// Suppress the unused variable warning by explicitly using N.
// The actual value needed for the pattern is already in V.
(void)N;
return V;
}]>;

def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;

def fixedpoint_recip_v4f16_v4i16 : fixedpoint_recip_vec_i16<v4f16>;
def fixedpoint_recip_v8f16_v8i16 : fixedpoint_recip_vec_i16<v8f16>;

let Predicates = [HasNEON] in {
def : Pat<(v2f32(fmul(sint_to_fp(v2i32 V64:$Rn)),
fixedpoint_recip_v2f32_v2i32:$scale)),
(v2f32(SCVTFv2i32_shift(v2i32 V64:$Rn),
(fixedpoint_recip_vec_xform fixedpoint_recip_v2f32_v2i32:$scale)))>;

def : Pat<(v4f32(fmul(sint_to_fp(v4i32 FPR128:$Rn)),
fixedpoint_recip_v4f32_v4i32:$scale)),
(v4f32(SCVTFv4i32_shift(v4i32 FPR128:$Rn),
(fixedpoint_recip_vec_xform fixedpoint_recip_v4f32_v4i32:$scale)))>;

def : Pat<(v2f64(fmul(sint_to_fp(v2i64 FPR128:$Rn)),
fixedpoint_recip_v2f64_v2i64:$scale)),
(v2f64(SCVTFv2i64_shift(v2i64 FPR128:$Rn),
(fixedpoint_recip_vec_xform fixedpoint_recip_v2f64_v2i64:$scale)))>;
}

let Predicates = [HasNEON, HasFullFP16] in {
def : Pat<(v4f16(fmul(sint_to_fp(v4i16 V64:$Rn)),
fixedpoint_recip_v4f16_v4i16:$scale)),
(v4f16(SCVTFv4i16_shift(v4i16 V64:$Rn),
(fixedpoint_recip_vec_xform fixedpoint_recip_v4f16_v4i16:$scale)))>;

def : Pat<(v8f16(fmul(sint_to_fp(v8i16 FPR128:$Rn)),
fixedpoint_recip_v8f16_v8i16:$scale)),
(v8f16(SCVTFv8i16_shift(v8i16 FPR128:$Rn),
(fixedpoint_recip_vec_xform fixedpoint_recip_v8f16_v8i16:$scale)))>;
}

// X << 1 ==> X + X
class SHLToADDPat<ValueType ty, RegisterClass regtype>
: Pat<(ty (AArch64vshl (ty regtype:$Rn), (i32 1))),
Expand Down
Loading
Loading