@@ -24689,6 +24689,105 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
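+// Returns true if N is a CONCAT_VECTORS node whose operands are, in order,
+// each sequential result of a single VECTOR_INTERLEAVE node. On success, the
+// inputs of that interleave are appended to Ops.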
+static bool
+isSequentialConcatOfVectorInterleave(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
+  if (N->getOpcode() != ISD::CONCAT_VECTORS)
+    return false;
+
+  unsigned NumParts = N->getNumOperands();
+
+  // We should be concatenating each sequential result from a
+  // VECTOR_INTERLEAVE.
+  SDNode *InterleaveOp = N->getOperand(0).getNode();
+  if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+      InterleaveOp->getNumOperands() != NumParts)
+    return false;
+
+  for (unsigned I = 0; I < NumParts; I++)
+    if (N->getOperand(I) != SDValue(InterleaveOp, I))
+      return false;
+
+  Ops.append(InterleaveOp->op_begin(), InterleaveOp->op_end());
+  return true;
+}
+
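+// Given the wide mask of an interleaved masked load or store, try to find the
+// single narrow mask shared by all parts: either the common input of an
+// interleave of identical masks, or a splat narrowed to the element count of
+// one part.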
+static SDValue getNarrowMaskForInterleavedOps(SelectionDAG &DAG, SDLoc &DL,
+                                              SDValue WideMask,
+                                              unsigned RequiredNumParts) {
+  if (WideMask->getOpcode() == ISD::CONCAT_VECTORS) {
+    SmallVector<SDValue, 4> MaskInterleaveOps;
+    if (!isSequentialConcatOfVectorInterleave(WideMask.getNode(),
+                                              MaskInterleaveOps))
+      return SDValue();
+
+    if (MaskInterleaveOps.size() != RequiredNumParts)
+      return SDValue();
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(MaskInterleaveOps))
+      return SDValue();
+
+    return MaskInterleaveOps[0];
+  }
+
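+  // Otherwise the wide mask must be a splat, which narrows to a splat of the
+  // same value with the reduced element count.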
+  if (WideMask->getOpcode() != ISD::SPLAT_VECTOR)
+    return SDValue();
+
+  ElementCount EC = WideMask.getValueType().getVectorElementCount();
+  assert(EC.isKnownMultipleOf(RequiredNumParts) &&
+         "Expected element count divisible by number of parts");
+  EC = EC.divideCoefficientBy(RequiredNumParts);
+  return DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                     WideMask->getOperand(0));
+}
+
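+// Try to combine a masked store whose stored value is a sequential concat of
+// VECTOR_INTERLEAVE results into a single SVE ST2 or ST4 intrinsic.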
+static SDValue performInterleavedMaskedStoreCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
+  SDValue WideValue = MST->getValue();
+
+  // Bail out if the stored value has an unexpected number of uses, since we'll
+  // have to perform manual interleaving and may as well just use normal masked
+  // stores. Also, discard masked stores that are truncating or indexed.
+  if (!WideValue.hasOneUse() || !ISD::isNormalMaskedStore(MST) ||
+      !MST->isSimple() || !MST->getOffset().isUndef())
+    return SDValue();
+
+  SmallVector<SDValue, 4> ValueInterleaveOps;
+  if (!isSequentialConcatOfVectorInterleave(WideValue.getNode(),
+                                            ValueInterleaveOps))
+    return SDValue();
+
+  unsigned NumParts = ValueInterleaveOps.size();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  // At the moment we're unlikely to see a fixed-width vector interleave as
+  // we usually generate shuffles instead.
+  EVT SubVecTy = ValueInterleaveOps[0].getValueType();
+  if (!SubVecTy.isScalableVT() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MST->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();
+
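+  // Replace the masked store with the equivalent SVE structured-store
+  // intrinsic, storing each interleave input under the narrowed mask.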
+  const Intrinsic::ID IID =
+      NumParts == 2 ? Intrinsic::aarch64_sve_st2 : Intrinsic::aarch64_sve_st4;
+  SmallVector<SDValue, 8> NewStOps;
+  NewStOps.append({MST->getChain(), DAG.getConstant(IID, DL, MVT::i32)});
+  NewStOps.append(ValueInterleaveOps);
+  NewStOps.append({NarrowMask, MST->getBasePtr()});
+  return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, NewStOps);
+}
+
 static SDValue performMSTORECombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG,
@@ -24698,6 +24797,9 @@ static SDValue performMSTORECombine(SDNode *N,
   SDValue Mask = MST->getMask();
   SDLoc DL(N);
 
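+  // Fold interleaved masked stores into SVE ST2/ST4 intrinsics where possible.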
+  if (SDValue Res = performInterleavedMaskedStoreCombine(N, DCI, DAG))
+    return Res;
+
   // If this is a UZP1 followed by a masked store, fold this into a masked
   // truncating store. We can do this even if this is already a masked
   // truncstore.
@@ -27331,43 +27433,11 @@ static SDValue performVectorDeinterleaveCombine(
     return SDValue();
 
   // Now prove that the mask is an interleave of identical masks.
-  SDValue Mask = MaskedLoad->getMask();
-  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
-      Mask->getOpcode() != ISD::CONCAT_VECTORS)
-    return SDValue();
-
-  SDValue NarrowMask;
   SDLoc DL(N);
-  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
-    if (Mask->getNumOperands() != NumParts)
-      return SDValue();
-
-    // We should be concatenating each sequential result from a
-    // VECTOR_INTERLEAVE.
-    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
-    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
-        InterleaveOp->getNumOperands() != NumParts)
-      return SDValue();
-
-    for (unsigned I = 0; I < NumParts; I++) {
-      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
-        return SDValue();
-    }
-
-    // Make sure the inputs to the vector interleave are identical.
-    if (!llvm::all_equal(InterleaveOp->op_values()))
-      return SDValue();
-
-    NarrowMask = InterleaveOp->getOperand(0);
-  } else { // ISD::SPLAT_VECTOR
-    ElementCount EC = Mask.getValueType().getVectorElementCount();
-    assert(EC.isKnownMultipleOf(NumParts) &&
-           "Expected element count divisible by number of parts");
-    EC = EC.divideCoefficientBy(NumParts);
-    NarrowMask =
-        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
-                    Mask->getOperand(0));
-  }
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MaskedLoad->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();
 
   const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
                                           : Intrinsic::aarch64_sve_ld4_sret;