@@ -24689,6 +24689,105 @@ static SDValue performSTORECombine(SDNode *N,
  return SDValue();
}

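+// Returns true if N is a CONCAT_VECTORS whose operands are, in order, all of
+// the results of a single VECTOR_INTERLEAVE node. On success, the inputs of
+// that VECTOR_INTERLEAVE are appended to Ops.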
+static bool
+isSequentialConcatOfVectorInterleave(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
+  if (N->getOpcode() != ISD::CONCAT_VECTORS)
+    return false;
+
+  unsigned NumParts = N->getNumOperands();
+
+  // We should be concatenating each sequential result from a
+  // VECTOR_INTERLEAVE.
+  SDNode *InterleaveOp = N->getOperand(0).getNode();
+  if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+      InterleaveOp->getNumOperands() != NumParts)
+    return false;
+
+  for (unsigned I = 0; I < NumParts; I++)
+    if (N->getOperand(I) != SDValue(InterleaveOp, I))
+      return false;
+
+  Ops.append(InterleaveOp->op_begin(), InterleaveOp->op_end());
+  return true;
+}
+
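+// Given the wide mask of an interleaved masked load or store, returns the
+// narrow mask to use for each individual part, provided the wide mask is
+// either a splat or a concatenation of the results of a VECTOR_INTERLEAVE of
+// RequiredNumParts identical masks. Returns SDValue() otherwise.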
+static SDValue getNarrowMaskForInterleavedOps(SelectionDAG &DAG, SDLoc &DL,
+                                              SDValue WideMask,
+                                              unsigned RequiredNumParts) {
+  if (WideMask->getOpcode() == ISD::CONCAT_VECTORS) {
+    SmallVector<SDValue, 4> MaskInterleaveOps;
+    if (!isSequentialConcatOfVectorInterleave(WideMask.getNode(),
+                                              MaskInterleaveOps))
+      return SDValue();
+
+    if (MaskInterleaveOps.size() != RequiredNumParts)
+      return SDValue();
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(MaskInterleaveOps))
+      return SDValue();
+
+    return MaskInterleaveOps[0];
+  }
+
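+  // Otherwise the wide mask must be a splat, which trivially narrows to a
+  // splat of the same value with fewer elements.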
+  if (WideMask->getOpcode() != ISD::SPLAT_VECTOR)
+    return SDValue();
+
+  ElementCount EC = WideMask.getValueType().getVectorElementCount();
+  assert(EC.isKnownMultipleOf(RequiredNumParts) &&
+         "Expected element count divisible by number of parts");
+  EC = EC.divideCoefficientBy(RequiredNumParts);
+  return DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                     WideMask->getOperand(0));
+}
+
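+// Tries to replace a masked store of an interleaved value with an SVE
+// structured store (ST2/ST4). Sketch of the transform:
+//   masked.store(concat(vector_interleave(a, b)), ptr, mask)
+//     --> aarch64_sve_st2(a, b, narrow_mask, ptr)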
+static SDValue performInterleavedMaskedStoreCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
+  SDValue WideValue = MST->getValue();
+
+  // Bail out if the stored value has an unexpected number of uses, since we'll
+  // have to perform manual interleaving and may as well just use normal masked
+  // stores. Also, discard masked stores that are truncating or indexed.
+  if (!WideValue.hasOneUse() || !ISD::isNormalMaskedStore(MST) ||
+      !MST->isSimple() || !MST->getOffset().isUndef())
+    return SDValue();
+
+  SmallVector<SDValue, 4> ValueInterleaveOps;
+  if (!isSequentialConcatOfVectorInterleave(WideValue.getNode(),
+                                            ValueInterleaveOps))
+    return SDValue();
+
+  unsigned NumParts = ValueInterleaveOps.size();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  // At the moment we're unlikely to see a fixed-width vector interleave as
+  // we usually generate shuffles instead.
+  EVT SubVecTy = ValueInterleaveOps[0].getValueType();
+  if (!SubVecTy.isScalableVT() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MST->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();
+
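+  // Emit the equivalent SVE structured-store intrinsic. Its operands are the
+  // chain, the intrinsic ID, the values to store, the narrow mask and the
+  // base pointer.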
+  const Intrinsic::ID IID =
+      NumParts == 2 ? Intrinsic::aarch64_sve_st2 : Intrinsic::aarch64_sve_st4;
+  SmallVector<SDValue, 8> NewStOps;
+  NewStOps.append({MST->getChain(), DAG.getConstant(IID, DL, MVT::i32)});
+  NewStOps.append(ValueInterleaveOps);
+  NewStOps.append({NarrowMask, MST->getBasePtr()});
+  return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, NewStOps);
+}
+
static SDValue performMSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -24698,6 +24797,9 @@ static SDValue performMSTORECombine(SDNode *N,
  SDValue Mask = MST->getMask();
  SDLoc DL(N);

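+  // First, try to combine a masked store of an interleaved value into an SVE
+  // structured store (ST2/ST4).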
+  if (SDValue Res = performInterleavedMaskedStoreCombine(N, DCI, DAG))
+    return Res;
+
  // If this is a UZP1 followed by a masked store, fold this into a masked
  // truncating store. We can do this even if this is already a masked
  // truncstore.
@@ -27331,43 +27433,11 @@ static SDValue performVectorDeinterleaveCombine(
    return SDValue();

  // Now prove that the mask is an interleave of identical masks.
-  SDValue Mask = MaskedLoad->getMask();
-  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
-      Mask->getOpcode() != ISD::CONCAT_VECTORS)
-    return SDValue();
-
-  SDValue NarrowMask;
  SDLoc DL(N);
-  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
-    if (Mask->getNumOperands() != NumParts)
-      return SDValue();
-
-    // We should be concatenating each sequential result from a
-    // VECTOR_INTERLEAVE.
-    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
-    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
-        InterleaveOp->getNumOperands() != NumParts)
-      return SDValue();
-
-    for (unsigned I = 0; I < NumParts; I++) {
-      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
-        return SDValue();
-    }
-
-    // Make sure the inputs to the vector interleave are identical.
-    if (!llvm::all_equal(InterleaveOp->op_values()))
-      return SDValue();
-
-    NarrowMask = InterleaveOp->getOperand(0);
-  } else { // ISD::SPLAT_VECTOR
-    ElementCount EC = Mask.getValueType().getVectorElementCount();
-    assert(EC.isKnownMultipleOf(NumParts) &&
-           "Expected element count divisible by number of parts");
-    EC = EC.divideCoefficientBy(NumParts);
-    NarrowMask =
-        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
-                    Mask->getOperand(0));
-  }
+  SDValue NarrowMask =
+      getNarrowMaskForInterleavedOps(DAG, DL, MaskedLoad->getMask(), NumParts);
+  if (!NarrowMask)
+    return SDValue();

  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
                                          : Intrinsic::aarch64_sve_ld4_sret;