Commit 73bed64

[AArch64] Improve lowering for scalable masked deinterleaving loads (#154338)
For IR like this:

  %mask = ... @llvm.vector.interleave2(<vscale x 16 x i1> %a, <vscale x 16 x i1> %a)
  %vec  = ... @llvm.masked.load(..., <vscale x 32 x i1> %mask, ...)
  %dvec = ... @llvm.vector.deinterleave2(<vscale x 32 x i8> %vec)

where we're deinterleaving a wide masked load of a supported type with an interleaved mask, we can lower this directly to a ld2b instruction. Similarly, we can also support other variants of ld2 and ld4.

This PR adds a DAG combine to spot such patterns and lower them to the ld2X or ld4X variants accordingly, whilst being careful to ensure the masked load is only used by the deinterleave intrinsic.
1 parent 349523e commit 73bed64
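
To make the pattern concrete, here is a self-contained IR sketch of the 2-way case the combine targets. This is illustrative only: the function name, i8 element type, alignment, and poison passthru are assumptions, not copied from the committed tests.

  declare <vscale x 32 x i1> @llvm.vector.interleave2.nxv32i1(<vscale x 16 x i1>, <vscale x 16 x i1>)
  declare <vscale x 32 x i8> @llvm.masked.load.nxv32i8.p0(ptr, i32 immarg, <vscale x 32 x i1>, <vscale x 32 x i8>)
  declare { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8>)

  define { <vscale x 16 x i8>, <vscale x 16 x i8> } @masked_load_deinterleave2(ptr %p, <vscale x 16 x i1> %a) {
    ; Interleave the narrow predicate with itself to form the wide mask.
    %mask = call <vscale x 32 x i1> @llvm.vector.interleave2.nxv32i1(<vscale x 16 x i1> %a, <vscale x 16 x i1> %a)
    ; Wide masked load whose only user is the deinterleave below.
    %vec = call <vscale x 32 x i8> @llvm.masked.load.nxv32i8.p0(ptr %p, i32 1, <vscale x 32 x i1> %mask, <vscale x 32 x i8> poison)
    ; Split the loaded data back into two legal <vscale x 16 x i8> halves.
    %dvec = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %vec)
    ret { <vscale x 16 x i8>, <vscale x 16 x i8> } %dvec
  }

With this shape, the combine below can replace the wide masked load plus deinterleave with a single predicated structured load on %a.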

File tree

4 files changed (+871 -0 lines changed)

4 files changed

+871
-0
lines changed

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 8 additions & 0 deletions
@@ -3338,6 +3338,14 @@ namespace ISD {
     return St && St->getAddressingMode() == ISD::UNINDEXED;
   }
 
+  /// Returns true if the specified node is a non-extending and unindexed
+  /// masked load.
+  inline bool isNormalMaskedLoad(const SDNode *N) {
+    auto *Ld = dyn_cast<MaskedLoadSDNode>(N);
+    return Ld && Ld->getExtensionType() == ISD::NON_EXTLOAD &&
+           Ld->getAddressingMode() == ISD::UNINDEXED;
+  }
+
   /// Attempt to match a unary predicate against a scalar/splat constant or
   /// every element of a constant BUILD_VECTOR.
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 112 additions & 0 deletions
@@ -1179,6 +1179,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
 
   setTargetDAGCombine(ISD::SHL);
+  setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
 
   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
@@ -27207,6 +27208,115 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return NVCAST;
 }
 
+static SDValue performVectorDeinterleaveCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  unsigned NumParts = N->getNumOperands();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  EVT SubVecTy = N->getValueType(0);
+
+  // At the moment we're unlikely to see a fixed-width vector deinterleave as
+  // we usually generate shuffles instead.
+  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
+  if (!SubVecTy.isScalableVector() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  // Make sure each input operand is the correct extract_subvector of the same
+  // wider vector.
+  SDValue Op0 = N->getOperand(0);
+  for (unsigned I = 0; I < NumParts; I++) {
+    SDValue OpI = N->getOperand(I);
+    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        OpI->getOperand(0) != Op0->getOperand(0))
+      return SDValue();
+    if (OpI->getConstantOperandVal(1) != (I * MinNumElements))
+      return SDValue();
+  }
+
+  // Normal loads are currently already handled by the InterleavedAccessPass so
+  // we don't expect to see them here. Bail out if the masked load has an
+  // unexpected number of uses, since we want to avoid a situation where we have
+  // both deinterleaving loads and normal loads in the same block. Also, discard
+  // masked loads that are extending, indexed, have an unexpected offset or have
+  // an unsupported passthru value until we find a valid use case.
+  auto MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
+  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
+      !MaskedLoad->isSimple() || !ISD::isNormalMaskedLoad(MaskedLoad) ||
+      !MaskedLoad->getOffset().isUndef() ||
+      (!MaskedLoad->getPassThru()->isUndef() &&
+       !isZerosVector(MaskedLoad->getPassThru().getNode())))
+    return SDValue();
+
+  // Now prove that the mask is an interleave of identical masks.
+  SDValue Mask = MaskedLoad->getMask();
+  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
+      Mask->getOpcode() != ISD::CONCAT_VECTORS)
+    return SDValue();
+
+  SDValue NarrowMask;
+  SDLoc DL(N);
+  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
+    if (Mask->getNumOperands() != NumParts)
+      return SDValue();
+
+    // We should be concatenating each sequential result from a
+    // VECTOR_INTERLEAVE.
+    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
+    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+        InterleaveOp->getNumOperands() != NumParts)
+      return SDValue();
+
+    for (unsigned I = 0; I < NumParts; I++) {
+      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
+        return SDValue();
+    }
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(InterleaveOp->op_values()))
+      return SDValue();
+
+    NarrowMask = InterleaveOp->getOperand(0);
+  } else { // ISD::SPLAT_VECTOR
+    ElementCount EC = Mask.getValueType().getVectorElementCount();
+    assert(EC.isKnownMultipleOf(NumParts) &&
+           "Expected element count divisible by number of parts");
+    EC = EC.divideCoefficientBy(NumParts);
+    NarrowMask =
+        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                    Mask->getOperand(0));
+  }
+
+  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
+                                          : Intrinsic::aarch64_sve_ld4_sret;
+  SDValue NewLdOps[] = {MaskedLoad->getChain(),
+                        DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
+                        MaskedLoad->getBasePtr()};
+  SDValue Res;
+  if (NumParts == 2)
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
+  else
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
+                      NewLdOps);
+
+  // We can now generate a structured load!
+  SmallVector<SDValue, 4> ResOps(NumParts);
+  for (unsigned Idx = 0; Idx < NumParts; Idx++)
+    ResOps[Idx] = SDValue(Res.getNode(), Idx);
+
+  // Replace uses of the original chain result with the new chain result.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
+                                SDValue(Res.getNode(), NumParts));
+  return DCI.CombineTo(N, ResOps, false);
+}
+
 /// If the operand is a bitwise AND with a constant RHS, and the shift has a
 /// constant RHS and is the only use, we can pull it out of the shift, i.e.
 ///
@@ -27275,6 +27385,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
+  case ISD::VECTOR_DEINTERLEAVE:
+    return performVectorDeinterleaveCombine(N, DCI, DAG);
   case ISD::VECREDUCE_AND:
   case ISD::VECREDUCE_OR:
   case ISD::VECREDUCE_XOR:
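
After the combine fires, the DAG contains an INTRINSIC_W_CHAIN node carrying the aarch64_sve_ld2_sret (or aarch64_sve_ld4_sret) intrinsic ID in place of the masked load and deinterleave. In IR terms, the 2-way example near the top of this page becomes roughly equivalent to the following. This is a hedged sketch: the combine operates on SelectionDAG nodes rather than rewriting IR, and the intrinsic name/mangling is assumed from the Intrinsic::aarch64_sve_ld2_sret ID used in the code.

  declare { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.ld2.sret.nxv16i8(<vscale x 16 x i1>, ptr)

  define { <vscale x 16 x i8>, <vscale x 16 x i8> } @masked_load_deinterleave2(ptr %p, <vscale x 16 x i1> %a) {
    ; One predicated structured load on the narrow mask replaces the wide
    ; masked load plus deinterleave, and selects to a single ld2b instruction.
    %res = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.ld2.sret.nxv16i8(<vscale x 16 x i1> %a, ptr %p)
    ret { <vscale x 16 x i8>, <vscale x 16 x i8> } %res
  }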
