Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23239,6 +23239,99 @@ static SDValue performZExtUZPCombine(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::AND, DL, VT, BC, DAG.getConstant(Mask, DL, VT));
}

// Helper function to optimize small vector load + extension patterns.
// These patterns would otherwise be scalarized into inefficient sequences.
//
// Matches (zext|sext|anyext (load <2 x i8>|<2 x i16>)) and rewrites it as a
// single scalar load of the same total width, placed in a 64-bit vector and
// widened by repeated vector extends.
//
// \param N   The extend node to combine (ZERO_EXTEND, SIGN_EXTEND or
//            ANY_EXTEND); any other opcode is rejected.
// \param DAG The SelectionDAG used to build the replacement nodes.
// \returns   A MERGE_VALUES of {extended vector, chain of the new load}, or an
//            empty SDValue when the pattern does not match.
static SDValue performSmallVectorLoadExtCombine(SDNode *N, SelectionDAG &DAG) {
  // Don't optimize if NEON is not available. Without NEON, the backend
  // will need to scalarize these operations anyway.
  const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
  if (!Subtarget.isNeonAvailable())
    return SDValue();
  // Don't optimize if SVE is being used for fixed-length vectors, because it
  // has native support for these patterns.
  if (Subtarget.useSVEForFixedLengthVectors())
    return SDValue();

  unsigned Opcode = N->getOpcode();
  if (Opcode != ISD::ZERO_EXTEND && Opcode != ISD::SIGN_EXTEND &&
      Opcode != ISD::ANY_EXTEND)
    return SDValue();

  SDValue Op = N->getOperand(0);
  if (Op.getOpcode() != ISD::LOAD)
    return SDValue();
  LoadSDNode *LD = cast<LoadSDNode>(Op);
  // Only plain loads are rewritten: the replacement below is an ordinary
  // unindexed load, so volatile/atomic semantics or an indexed writeback
  // result would be lost. The load must also have no other users, since the
  // original node is not kept alive.
  if (LD->getExtensionType() != ISD::NON_EXTLOAD || !LD->isUnindexed() ||
      !LD->hasOneUse() || !LD->isSimple())
    return SDValue();

  EVT MemVT = LD->getMemoryVT();
  EVT ResVT = N->getValueType(0);
  // Check if this is a small vector pattern we want to optimize.
  if (MemVT != MVT::v2i8 && MemVT != MVT::v2i16)
    return SDValue();

  unsigned NumElts = MemVT.getVectorNumElements();
  unsigned SrcEltBits = MemVT.getScalarSizeInBits();
  unsigned DstEltBits = ResVT.getScalarSizeInBits();
  unsigned LoadBits = NumElts * SrcEltBits;

  // The extend loop below doubles the element width each step, so it can only
  // land exactly on 16, 32 or 64 bits. Any other destination width (e.g. i24
  // or i128) would either produce a value of the wrong type or request a
  // vector type that does not exist, so leave those to generic legalization.
  if (DstEltBits != 16 && DstEltBits != 32 && DstEltBits != 64)
    return SDValue();

  // Check alignment: the optimization loads a larger scalar, which may be
  // unaligned, compared to what the original load will be legalized into.
  // Align is expressed in bytes, so compare against the load size in bytes.
  Align Alignment = LD->getAlign();
  unsigned LoadBytes = LoadBits / 8;
  if (Subtarget.requiresStrictAlign() && Alignment < LoadBytes)
    return SDValue();

  // The transformation strategy:
  // 1. Load the memory as a large scalar and turn it into a 64-bit vector.
  // 2. Bitcast to a narrow type (v8i8 or v4i16) that has efficient NEON extend.
  // 3. Extend using ushll/sshll, extract subvector, repeat as needed.

  // For ANY_EXTEND, we can choose either sign or zero extend - zero is
  // typically cheaper.
  if (Opcode == ISD::ANY_EXTEND)
    Opcode = ISD::ZERO_EXTEND;

  SDLoc DL(N);
  SDValue Chain = LD->getChain();
  SDValue BasePtr = LD->getBasePtr();
  const MachinePointerInfo &PtrInfo = LD->getPointerInfo();
  MVT LoadTy = MVT::getIntegerVT(LoadBits);
  // The new load covers exactly the same bytes as the original one, so carry
  // over its memory-operand flags and alias-analysis metadata.
  SDValue Load = DAG.getLoad(LoadTy, DL, Chain, BasePtr, PtrInfo, Alignment,
                             LD->getMemOperand()->getFlags(), LD->getAAInfo());

  // SCALAR_TO_VECTOR needs to create a 64-bit vector for NEON instructions.
  // The scalar load is inserted into the lower bits of a 64-bit register.
  // We determine the appropriate 64-bit vector type based on load size,
  // then bitcast to v8i8 or v4i16 for efficient ushll/sshll extends.
  MVT ScalarVecVT = MVT::getVectorVT(LoadTy, 64 / LoadBits);
  MVT NarrowVT = MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
                                  64 / SrcEltBits);

  SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarVecVT, Load);
  Vec = DAG.getNode(ISD::BITCAST, DL, NarrowVT, Vec);
  // Extend iteratively: each extend doubles the element size.
  // We extend the full 64-bit vector to leverage NEON ushll/sshll instructions.
  while (Vec.getScalarValueSizeInBits() < DstEltBits) {
    MVT CurVT = Vec.getSimpleValueType();
    unsigned NextBits = CurVT.getScalarSizeInBits() * 2;
    MVT WideVT = MVT::getVectorVT(MVT::getIntegerVT(NextBits),
                                  CurVT.getVectorNumElements());
    Vec = DAG.getNode(Opcode, DL, WideVT, Vec);

    // Extract only when: excess elements + still wide + done extending.
    bool HasExcess = WideVT.getVectorNumElements() > NumElts;
    bool StaysWide = WideVT.getSizeInBits() >= 64;
    bool IsDone = NextBits >= DstEltBits;
    if (HasExcess && StaysWide && IsDone) {
      MVT ExtractVT = MVT::getVectorVT(WideVT.getScalarType(), NumElts);
      Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Vec,
                        DAG.getConstant(0, DL, MVT::i64));
    }
  }

  // Hand back both the extended value and the new load's chain.
  return DAG.getMergeValues({Vec, Load.getValue(1)}, DL);
}

static SDValue performExtendCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
Expand Down Expand Up @@ -23288,6 +23381,12 @@ static SDValue performExtendCombine(SDNode *N,
NewAnyExtend);
}

// Try to optimize small vector load + extension patterns.
if (SDValue Result = performSmallVectorLoadExtCombine(N, DAG))
return Result;

return SDValue();
}

Expand Down
Loading
Loading