
Commit dd0161f

[AArch64] Improve lowering for scalable masked interleaving stores (#156718)
Similar to #154338, this PR aims to support lowering of certain IR to SVE's st2 and st4 instructions. The typical IR scenario looks like:

  %mask = .. @llvm.vector.interleave2(<vscale x 16 x i1> %m, <vscale x 16 x i1> %m)
  %val = .. @llvm.vector.interleave2(<vscale x 16 x i8> %v1, <vscale x 16 x i8> %v2)
  .. @llvm.masked.store.nxv32i8.p0(<vscale x 32 x i8> %val, ..., <vscale x 32 x i1> %mask)

where we're interleaving both the value and the mask being passed to the wide store. When the mask interleave parts are identical we can lower this to st2b. This PR adds a DAG combine for lowering this kind of IR pattern to st2X and st4X SVE instructions.
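For intuition, here is a rough sketch of the kind of source loop that can give rise to this IR shape once it is vectorized for SVE with a predicated (tail-folded) interleaved store; the function and names below are illustrative only and are not taken from this PR:

  // Illustrative C++ sketch: storing two streams as interleaved pairs.
  // With SVE auto-vectorization and tail folding, the pair of stores can
  // become one wide masked store whose value and mask are both
  // vector.interleave2 results, matching the IR above.
  void store_pairs(signed char *dst, const signed char *a,
                   const signed char *b, long n) {
    for (long i = 0; i < n; ++i) {
      dst[2 * i] = a[i];     // even lanes come from a
      dst[2 * i + 1] = b[i]; // odd lanes come from b
    }
  }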
1 parent c71da7d commit dd0161f

File tree

4 files changed: +857 -36 lines changed

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 8 additions & 0 deletions
@@ -3346,6 +3346,14 @@ namespace ISD {
            Ld->getAddressingMode() == ISD::UNINDEXED;
   }
 
+  /// Returns true if the specified node is a non-extending and unindexed
+  /// masked store.
+  inline bool isNormalMaskedStore(const SDNode *N) {
+    auto *St = dyn_cast<MaskedStoreSDNode>(N);
+    return St && !St->isTruncatingStore() &&
+           St->getAddressingMode() == ISD::UNINDEXED;
+  }
+
   /// Attempt to match a unary predicate against a scalar/splat constant or
   /// every element of a constant BUILD_VECTOR.
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
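Below is a hedged usage sketch of the new helper inside a hypothetical target combine, assuming the usual SelectionDAG headers; only ISD::isNormalMaskedStore and MaskedStoreSDNode come from the tree, the surrounding function is made up for illustration:

  // Hypothetical caller: only act on plain (non-truncating, unindexed)
  // masked stores, mirroring how ISD::isNormalLoad/isNormalStore are used.
  static SDValue combineMaskedStoreExample(SDNode *N, SelectionDAG &DAG) {
    if (!ISD::isNormalMaskedStore(N))
      return SDValue();
    auto *MST = cast<MaskedStoreSDNode>(N);
    SDValue Data = MST->getValue(); // vector being stored
    SDValue Mask = MST->getMask();  // per-lane predicate
    // ... a target-specific rewrite using Data and Mask would go here ...
    return SDValue();
  }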

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 106 additions & 36 deletions
@@ -24689,6 +24689,105 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
+static bool
+isSequentialConcatOfVectorInterleave(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
+  if (N->getOpcode() != ISD::CONCAT_VECTORS)
+    return false;
+
+  unsigned NumParts = N->getNumOperands();
+
+  // We should be concatenating each sequential result from a
+  // VECTOR_INTERLEAVE.
+  SDNode *InterleaveOp = N->getOperand(0).getNode();
+  if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+      InterleaveOp->getNumOperands() != NumParts)
+    return false;
+
+  for (unsigned I = 0; I < NumParts; I++)
+    if (N->getOperand(I) != SDValue(InterleaveOp, I))
+      return false;
+
+  Ops.append(InterleaveOp->op_begin(), InterleaveOp->op_end());
+  return true;
+}
+
+static SDValue getNarrowMaskForInterleavedOps(SelectionDAG &DAG, SDLoc &DL,
+                                              SDValue WideMask,
+                                              unsigned RequiredNumParts) {
+  if (WideMask->getOpcode() == ISD::CONCAT_VECTORS) {
+    SmallVector<SDValue, 4> MaskInterleaveOps;
+    if (!isSequentialConcatOfVectorInterleave(WideMask.getNode(),
+                                              MaskInterleaveOps))
+      return SDValue();
+
+    if (MaskInterleaveOps.size() != RequiredNumParts)
+      return SDValue();
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(MaskInterleaveOps))
+      return SDValue();
+
+    return MaskInterleaveOps[0];
+  }
+
+  if (WideMask->getOpcode() != ISD::SPLAT_VECTOR)
+    return SDValue();
+
+  ElementCount EC = WideMask.getValueType().getVectorElementCount();
+  assert(EC.isKnownMultipleOf(RequiredNumParts) &&
+         "Expected element count divisible by number of parts");
+  EC = EC.divideCoefficientBy(RequiredNumParts);
+  return DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                     WideMask->getOperand(0));
+}
+
+static SDValue performInterleavedMaskedStoreCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
+  SDValue WideValue = MST->getValue();
+
+  // Bail out if the stored value has an unexpected number of uses, since we'll
+  // have to perform manual interleaving and may as well just use normal masked
+  // stores. Also, discard masked stores that are truncating or indexed.
+  if (!WideValue.hasOneUse() || !ISD::isNormalMaskedStore(MST) ||
+      !MST->isSimple() || !MST->getOffset().isUndef())
+    return SDValue();
+
+  SmallVector<SDValue, 4> ValueInterleaveOps;
+  if (!isSequentialConcatOfVectorInterleave(WideValue.getNode(),
+                                            ValueInterleaveOps))
+    return SDValue();
+
+  unsigned NumParts = ValueInterleaveOps.size();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  // At the moment we're unlikely to see a fixed-width vector interleave as
+  // we usually generate shuffles instead.
+  EVT SubVecTy = ValueInterleaveOps[0].getValueType();
+  if (!SubVecTy.isScalableVT() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MST->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();
+
+  const Intrinsic::ID IID =
+      NumParts == 2 ? Intrinsic::aarch64_sve_st2 : Intrinsic::aarch64_sve_st4;
+  SmallVector<SDValue, 8> NewStOps;
+  NewStOps.append({MST->getChain(), DAG.getConstant(IID, DL, MVT::i32)});
+  NewStOps.append(ValueInterleaveOps);
+  NewStOps.append({NarrowMask, MST->getBasePtr()});
+  return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, NewStOps);
+}
+
 static SDValue performMSTORECombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG,
@@ -24698,6 +24797,9 @@ static SDValue performMSTORECombine(SDNode *N,
   SDValue Mask = MST->getMask();
   SDLoc DL(N);
 
+  if (SDValue Res = performInterleavedMaskedStoreCombine(N, DCI, DAG))
+    return Res;
+
   // If this is a UZP1 followed by a masked store, fold this into a masked
   // truncating store. We can do this even if this is already a masked
   // truncstore.
@@ -27331,43 +27433,11 @@ static SDValue performVectorDeinterleaveCombine(
     return SDValue();
 
   // Now prove that the mask is an interleave of identical masks.
-  SDValue Mask = MaskedLoad->getMask();
-  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
-      Mask->getOpcode() != ISD::CONCAT_VECTORS)
-    return SDValue();
-
-  SDValue NarrowMask;
   SDLoc DL(N);
-  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
-    if (Mask->getNumOperands() != NumParts)
-      return SDValue();
-
-    // We should be concatenating each sequential result from a
-    // VECTOR_INTERLEAVE.
-    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
-    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
-        InterleaveOp->getNumOperands() != NumParts)
-      return SDValue();
-
-    for (unsigned I = 0; I < NumParts; I++) {
-      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
-        return SDValue();
-    }
-
-    // Make sure the inputs to the vector interleave are identical.
-    if (!llvm::all_equal(InterleaveOp->op_values()))
-      return SDValue();
-
-    NarrowMask = InterleaveOp->getOperand(0);
-  } else { // ISD::SPLAT_VECTOR
-    ElementCount EC = Mask.getValueType().getVectorElementCount();
-    assert(EC.isKnownMultipleOf(NumParts) &&
-           "Expected element count divisible by number of parts");
-    EC = EC.divideCoefficientBy(NumParts);
-    NarrowMask =
-        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
-                    Mask->getOperand(0));
-  }
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MaskedLoad->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();
 
   const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
                                           : Intrinsic::aarch64_sve_ld4_sret;
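For intuition about the instruction being targeted, the following is a rough ACLE-level equivalent of the predicated st2b that the new combine selects in the two-part case; it assumes <arm_sve.h> on an SVE-enabled toolchain and is illustrative only, not part of the patch:

  #include <arm_sve.h>

  // Store v1 and v2 interleaved at dst. Each active lane of pg guards one
  // {v1[i], v2[i]} pair, which is the effect of the masked interleaving
  // store pattern described in the commit message.
  void st2_pairs(int8_t *dst, svbool_t pg, svint8_t v1, svint8_t v2) {
    svst2_s8(pg, dst, svcreate2_s8(v1, v2));
  }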
