@@ -5711,46 +5711,6 @@ SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
57115711 Mask);
57125712}
57135713
// Forward declaration; defined elsewhere in this file. Presumably returns the
// intrinsic ID of an INTRINSIC_* SDNode (and a non-intrinsic sentinel
// otherwise) — TODO confirm against the definition. Used below to recognise
// SVE strided-load intrinsics.
static unsigned getIntrinsicID(const SDNode *N);

// Try to lower a multi-vector SME dot intrinsic whose Size (2 or 4) vector
// operands are all results of a matching SVE strided load
// (aarch64_sve_ld1_pn_x2 / _x4). When they are, bundle those operands into a
// FORM_STRIDED_TUPLE_X2/_X4 node and rebuild the dot node on top of the
// tuple's results; otherwise return an empty SDValue() so the caller falls
// back to the default lowering.
SDValue TryLowerMultiVecSMEDotIntrinsic(SDValue Op, SelectionDAG &DAG,
                                        unsigned Size) {
  assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
  // True iff Op is the strided-load intrinsic matching the tuple size.
  auto IsStridedLoad = [Size](SDValue Op) -> bool {
    unsigned Intrinsic = getIntrinsicID(Op.getNode());
    if (Size == 2)
      return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x2;
    else
      return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x4;
  };

  // Operands [3, 3+Size) of the dot node are the multi-vector inputs. Each
  // must be a strided load and all must use the same result number of their
  // load, otherwise bail out.
  SmallVector<SDValue> Ops;
  unsigned LastLoadIdx = Size == 2 ? 5 : 7;
  unsigned LoadResNo = Op.getOperand(3).getResNo();
  for (unsigned I = 3; I < LastLoadIdx; I++) {
    if (!IsStridedLoad(Op->getOperand(I)) ||
        Op.getOperand(I).getResNo() != LoadResNo)
      return SDValue();
    Ops.push_back(Op->getOperand(I));
  }

  // Build the strided-tuple node producing Size results of the loads' VT.
  EVT VT = Op->getOperand(3).getValueType();
  SDVTList VTList =
      Size == 2 ? DAG.getVTList(VT, VT) : DAG.getVTList(VT, VT, VT, VT);
  unsigned Opc = Size == 2 ? AArch64ISD::FORM_STRIDED_TUPLE_X2
                           : AArch64ISD::FORM_STRIDED_TUPLE_X4;
  SDLoc DL(Op);
  SDValue Pseudo = DAG.getNode(Opc, DL, VTList, Ops);

  // Rebuild the dot node: chain (operand 0), operands 1-2, the tuple's Size
  // results, then the original node's two trailing operands. Note each
  // DotOps.size() call intentionally evaluates AFTER the previous push_back,
  // so it indexes the next original operand (3+Size, then 4+Size).
  SmallVector<SDValue> DotOps = {Op.getOperand(0), Op->getOperand(1),
                                 Op->getOperand(2)};
  for (unsigned I = 0; I < Size; I++)
    DotOps.push_back(Pseudo.getValue(I));
  DotOps.push_back(Op->getOperand(DotOps.size()));
  DotOps.push_back(Op->getOperand(DotOps.size()));
  return DAG.getNode(Op->getOpcode(), DL, MVT::Other, DotOps);
}
5753-
57545714// Lower an SME LDR/STR ZA intrinsic
57555715// Case 1: If the vector number (vecnum) is an immediate in range, it gets
57565716// folded into the instruction
@@ -5940,22 +5900,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
59405900 Op->getOperand(0), // Chain
59415901 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
59425902 DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
  // SME multi-vector dot-product intrinsics, vg1x4 variants: attempt to
  // bundle their 4 vector operands into a FORM_STRIDED_TUPLE_X4 node
  // (see TryLowerMultiVecSMEDotIntrinsic).
  case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_suvdot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_usvdot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_udot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x4:
  case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x4:
    return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 4);
  // vg1x2 variants: same lowering with a tuple size of 2.
  case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x2:
  case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x2:
  case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x2:
  case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x2:
  case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x2:
  case Intrinsic::aarch64_sme_udot_lane_za32_vg1x2:
    return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 2);
59595903 }
59605904}
59615905
@@ -8729,6 +8673,77 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
87298673 }
87308674 }
87318675
8676+ if (MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ||
8677+ MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO) {
8678+ MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
8679+ bool UseFormStrided = false;
8680+ unsigned Size =
8681+ MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ? 2 : 4;
8682+
8683+ // The FORM_STRIDED_TUPLE pseudo should only be used if the input operands
8684+ // are copy nodes where the source register is in a StridedOrContiguous
8685+ // class. For example:
8686+ // %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
8687+ // %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
8688+ // %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
8689+ // %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
8690+ // %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
8691+ // %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
8692+ // %9:zpr2mul2 = FORM_STRIDED_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
8693+
8694+ SmallVector<unsigned, 4> OpSubRegs;
8695+ for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
8696+ MachineOperand &MO = MI.getOperand(I);
8697+ if (!MO.isReg())
8698+ continue;
8699+
8700+ MachineOperand *Def = MRI.getOneDef(MO.getReg());
8701+ if (!Def || !Def->isReg() || !Def->getParent()->isCopy())
8702+ continue;
8703+
8704+ MachineInstr *Cpy = Def->getParent();
8705+ MachineOperand CpyOp = Cpy->getOperand(1);
8706+ if (!CpyOp.isReg())
8707+ continue;
8708+
8709+ MachineOperand *Ld = MRI.getOneDef(CpyOp.getReg());
8710+ OpSubRegs.push_back(CpyOp.getSubReg());
8711+ if (!Ld || !Ld->isReg())
8712+ continue;
8713+
8714+ const TargetRegisterClass *RegClass =
8715+ Size == 2 ? &AArch64::ZPR2StridedOrContiguousRegClass
8716+ : &AArch64::ZPR4StridedOrContiguousRegClass;
8717+
8718+ if (MRI.getRegClass(Ld->getReg()) == RegClass)
8719+ UseFormStrided = true;
8720+ }
8721+
8722+ // Ensure the operands all use the same subreg index.
8723+ if (!std::equal(OpSubRegs.begin(), OpSubRegs.end(), OpSubRegs.begin()))
8724+ UseFormStrided = false;
8725+
8726+ // If input values to the FORM_STRIDED_TUPLE pseudo aren't copies from a
8727+ // StridedOrContiguous class, fall back on REG_SEQUENCE node.
8728+ if (!UseFormStrided) {
8729+ static const unsigned SubRegs[] = {AArch64::zsub0, AArch64::zsub1,
8730+ AArch64::zsub2, AArch64::zsub3};
8731+
8732+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
8733+ MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
8734+ TII->get(TargetOpcode::REG_SEQUENCE),
8735+ MI.getOperand(0).getReg());
8736+
8737+ for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
8738+ MIB.add(MI.getOperand(I));
8739+ MIB.addImm(SubRegs[I - 1]);
8740+ }
8741+
8742+ MI.eraseFromParent();
8743+ }
8744+ return;
8745+ }
8746+
87328747 // Add an implicit use of 'VG' for ADDXri/SUBXri, which are instructions that
87338748 // have nothing to do with VG, were it not that they are used to materialise a
87348749 // frame-address. If they contain a frame-index to a scalable vector, this