@@ -1179,6 +1179,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);

   setTargetDAGCombine(ISD::SHL);
+  setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);

   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
@@ -27207,6 +27208,115 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return NVCAST;
 }

+static SDValue performVectorDeinterleaveCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  unsigned NumParts = N->getNumOperands();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  EVT SubVecTy = N->getValueType(0);
+
+  // At the moment we're unlikely to see a fixed-width vector deinterleave as
+  // we usually generate shuffles instead.
+  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
+  if (!SubVecTy.isScalableVector() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
+    return SDValue();
+
+  // Make sure each input operand is the correct extract_subvector of the same
+  // wider vector.
+  SDValue Op0 = N->getOperand(0);
+  for (unsigned I = 0; I < NumParts; I++) {
+    SDValue OpI = N->getOperand(I);
+    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        OpI->getOperand(0) != Op0->getOperand(0))
+      return SDValue();
+    if (OpI->getConstantOperandVal(1) != (I * MinNumElements))
+      return SDValue();
+  }
+
+  // Normal loads are currently already handled by the InterleavedAccessPass so
+  // we don't expect to see them here. Bail out if the masked load has an
+  // unexpected number of uses, since we want to avoid a situation where we have
+  // both deinterleaving loads and normal loads in the same block. Also, discard
+  // masked loads that are extending, indexed, have an unexpected offset or have
+  // an unsupported passthru value until we find a valid use case.
+  auto MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
+  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
+      !MaskedLoad->isSimple() || !ISD::isNormalMaskedLoad(MaskedLoad) ||
+      !MaskedLoad->getOffset().isUndef() ||
+      (!MaskedLoad->getPassThru()->isUndef() &&
+       !isZerosVector(MaskedLoad->getPassThru().getNode())))
+    return SDValue();
+
+  // Now prove that the mask is an interleave of identical masks.
+  SDValue Mask = MaskedLoad->getMask();
+  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
+      Mask->getOpcode() != ISD::CONCAT_VECTORS)
+    return SDValue();
+
+  SDValue NarrowMask;
+  SDLoc DL(N);
+  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
+    if (Mask->getNumOperands() != NumParts)
+      return SDValue();
+
+    // We should be concatenating each sequential result from a
+    // VECTOR_INTERLEAVE.
+    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
+    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+        InterleaveOp->getNumOperands() != NumParts)
+      return SDValue();
+
+    for (unsigned I = 0; I < NumParts; I++) {
+      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
+        return SDValue();
+    }
+
+    // Make sure the inputs to the vector interleave are identical.
+    if (!llvm::all_equal(InterleaveOp->op_values()))
+      return SDValue();
+
+    NarrowMask = InterleaveOp->getOperand(0);
+  } else { // ISD::SPLAT_VECTOR
+    ElementCount EC = Mask.getValueType().getVectorElementCount();
+    assert(EC.isKnownMultipleOf(NumParts) &&
+           "Expected element count divisible by number of parts");
+    EC = EC.divideCoefficientBy(NumParts);
+    NarrowMask =
+        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                    Mask->getOperand(0));
+  }
+
+  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
+                                          : Intrinsic::aarch64_sve_ld4_sret;
+  SDValue NewLdOps[] = {MaskedLoad->getChain(),
+                        DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
+                        MaskedLoad->getBasePtr()};
+  SDValue Res;
+  if (NumParts == 2)
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
+  else
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
+                      NewLdOps);
+
+  // We can now generate a structured load!
+  SmallVector<SDValue, 4> ResOps(NumParts);
+  for (unsigned Idx = 0; Idx < NumParts; Idx++)
+    ResOps[Idx] = SDValue(Res.getNode(), Idx);
+
+  // Replace uses of the original chain result with the new chain result.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
+                                SDValue(Res.getNode(), NumParts));
+  return DCI.CombineTo(N, ResOps, false);
+}
+
 /// If the operand is a bitwise AND with a constant RHS, and the shift has a
 /// constant RHS and is the only use, we can pull it out of the shift, i.e.
 ///
@@ -27275,6 +27385,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
+  case ISD::VECTOR_DEINTERLEAVE:
+    return performVectorDeinterleaveCombine(N, DCI, DAG);
   case ISD::VECREDUCE_AND:
   case ISD::VECREDUCE_OR:
   case ISD::VECREDUCE_XOR:
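For context, the combine above looks for an ISD::VECTOR_DEINTERLEAVE whose parts are extract_subvectors of a single masked load with an interleaved (or splatted) predicate, and rewrites the whole pattern into an SVE ld2/ld4 structured-load intrinsic. A minimal sketch of the kind of source that can give rise to this pattern is shown below; the function name and the assumption that auto-vectorization with SVE predication/tail folding produces a masked load feeding a vector deinterleave are illustrative only and not taken from this commit.

// Illustrative sketch only: a scalar loop with a stride-two (even/odd)
// access pattern. Under SVE auto-vectorization with predication/tail
// folding (assumed pipeline behaviour, not shown in this commit), the
// interleaved access can become a masked load whose results feed a
// vector deinterleave, which the combine above can then turn into an
// aarch64_sve_ld2_sret structured load.
void split_even_odd(const float *in, float *even, float *odd, int n) {
  for (int i = 0; i < n; ++i) {
    even[i] = in[2 * i];     // lane 0 of each pair
    odd[i]  = in[2 * i + 1]; // lane 1 of each pair
  }
}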