@@ -1179,6 +1179,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
 
   setTargetDAGCombine(ISD::SHL);
+  setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
 
   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
@@ -27207,6 +27208,115 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return NVCAST;
 }
 
+static SDValue performVectorDeinterleaveCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
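+  // Roughly, the NumParts == 2 shape this combine looks for is:
+  //   Wide = masked_load ptr, mask, passthru
+  //   Lo   = extract_subvector Wide, 0
+  //   Hi   = extract_subvector Wide, MinNumElements
+  //   Res0, Res1 = vector_deinterleave Lo, Hi
+  // where the mask is an interleave of identical predicates, so the whole
+  // sequence can be replaced by an SVE structured load (ld2/ld4).
+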
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  unsigned NumParts = N->getNumOperands();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  EVT SubVecTy = N->getValueType(0);
+
+  // At the moment we're unlikely to see a fixed-width vector deinterleave as
+  // we usually generate shuffles instead.
+  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
+  if (!SubVecTy.isScalableVector() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  // Make sure each input operand is the correct extract_subvector of the same
+  // wider vector.
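+  // E.g. for two nxv4i32 parts the extracts must be taken at element offsets
+  // 0 and 4 of the same wide load (0, 4, 8 and 12 for the four-part case).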
+  SDValue Op0 = N->getOperand(0);
+  for (unsigned I = 0; I < NumParts; I++) {
+    SDValue OpI = N->getOperand(I);
+    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        OpI->getOperand(0) != Op0->getOperand(0))
+      return SDValue();
+    if (OpI->getConstantOperandVal(1) != (I * MinNumElements))
+      return SDValue();
+  }
+
+  // Normal loads are currently already handled by the InterleavedAccessPass so
+  // we don't expect to see them here. Bail out if the masked load has an
+  // unexpected number of uses, since we want to avoid a situation where we have
+  // both deinterleaving loads and normal loads in the same block. Also, discard
+  // masked loads that are extending, indexed, have an unexpected offset or have
+  // an unsupported passthru value until we find a valid use case.
+  auto MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
+  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
+      !MaskedLoad->isSimple() || !ISD::isNormalMaskedLoad(MaskedLoad) ||
+      !MaskedLoad->getOffset().isUndef() ||
+      (!MaskedLoad->getPassThru()->isUndef() &&
+       !isZerosVector(MaskedLoad->getPassThru().getNode())))
+    return SDValue();
+
+  // Now prove that the mask is an interleave of identical masks.
+  SDValue Mask = MaskedLoad->getMask();
+  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
+      Mask->getOpcode() != ISD::CONCAT_VECTORS)
+    return SDValue();
+
+  SDValue NarrowMask;
+  SDLoc DL(N);
+  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
+    if (Mask->getNumOperands() != NumParts)
+      return SDValue();
+
+    // We should be concatenating each sequential result from a
+    // VECTOR_INTERLEAVE.
+    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
+    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+        InterleaveOp->getNumOperands() != NumParts)
+      return SDValue();
+
+    for (unsigned I = 0; I < NumParts; I++) {
+      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
+        return SDValue();
+    }
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(InterleaveOp->op_values()))
+      return SDValue();
+
+    NarrowMask = InterleaveOp->getOperand(0);
+  } else { // ISD::SPLAT_VECTOR
+    ElementCount EC = Mask.getValueType().getVectorElementCount();
+    assert(EC.isKnownMultipleOf(NumParts) &&
+           "Expected element count divisible by number of parts");
+    EC = EC.divideCoefficientBy(NumParts);
+    NarrowMask =
+        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                    Mask->getOperand(0));
+  }
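+  // For example, with NumParts == 2 and an nxv8i1 mask, NarrowMask is either
+  // the single nxv4i1 input of the VECTOR_INTERLEAVE or an nxv4i1 splat of
+  // the same scalar condition.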
+
+  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
+                                          : Intrinsic::aarch64_sve_ld4_sret;
+  SDValue NewLdOps[] = {MaskedLoad->getChain(),
+                        DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
+                        MaskedLoad->getBasePtr()};
+  SDValue Res;
+  if (NumParts == 2)
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
+  else
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
+                      NewLdOps);
+
+  // We can now generate a structured load!
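+  // Results 0 .. NumParts-1 of Res are the deinterleaved vectors; result
+  // NumParts is the new chain.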
+  SmallVector<SDValue, 4> ResOps(NumParts);
+  for (unsigned Idx = 0; Idx < NumParts; Idx++)
+    ResOps[Idx] = SDValue(Res.getNode(), Idx);
+
+  // Replace uses of the original chain result with the new chain result.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
+                                SDValue(Res.getNode(), NumParts));
+  return DCI.CombineTo(N, ResOps, false);
+}
+
 /// If the operand is a bitwise AND with a constant RHS, and the shift has a
 /// constant RHS and is the only use, we can pull it out of the shift, i.e.
 ///
@@ -27275,6 +27385,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
+  case ISD::VECTOR_DEINTERLEAVE:
+    return performVectorDeinterleaveCombine(N, DCI, DAG);
   case ISD::VECREDUCE_AND:
   case ISD::VECREDUCE_OR:
   case ISD::VECREDUCE_XOR: