Skip to content

Commit 0d4c931

Browse files
[AArch64][SME2] Add FORM_STRIDED_TUPLE pseudo nodes
This patch adds a pseudo node to help towards improving register allocation of multi-vector SME intrinsics. The FORM_STRIDED_TUPLE node is emitted if each of the operands of a contiguous multi-vector dot intrinsic are the result of a strided multi-vector load. The operands of the psuedo will be one subregister at the same index from each of these strided loads. Follow up patches will use this pseudo when adding register allocation hints to remove unecessary register copies in this scenario. Subregister liveness is also required to achieve this and has been enabled in the tests changed by this patch. Patch contains changes by Matthew Devereau.
1 parent 6992da2 commit 0d4c931

File tree

8 files changed

+378
-240
lines changed

8 files changed

+378
-240
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
6767
TargetRegisterClass ContiguousClass,
6868
TargetRegisterClass StridedClass,
6969
unsigned ContiguousOpc, unsigned StridedOpc);
70+
bool expandFormTuplePseudo(MachineBasicBlock &MBB,
71+
MachineBasicBlock::iterator MBBI,
72+
MachineBasicBlock::iterator &NextMBBI,
73+
unsigned Size);
7074
bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
7175
unsigned BitSize);
7276

@@ -1142,6 +1146,30 @@ bool AArch64ExpandPseudo::expandMultiVecPseudo(
11421146
return true;
11431147
}
11441148

1149+
bool AArch64ExpandPseudo::expandFormTuplePseudo(
1150+
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
1151+
MachineBasicBlock::iterator &NextMBBI, unsigned Size) {
1152+
assert(Size == 2 || Size == 4 && "Invalid Tuple Size");
1153+
MachineInstr &MI = *MBBI;
1154+
Register ReturnTuple = MI.getOperand(0).getReg();
1155+
1156+
const TargetRegisterInfo *TRI =
1157+
MBB.getParent()->getSubtarget().getRegisterInfo();
1158+
for (unsigned i = 0; i < Size; i++) {
1159+
Register FormTupleOpReg = MI.getOperand(i + 1).getReg();
1160+
Register ReturnTupleSubReg =
1161+
TRI->getSubReg(ReturnTuple, AArch64::zsub0 + i);
1162+
if (FormTupleOpReg != ReturnTupleSubReg)
1163+
BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORR_ZZZ))
1164+
.addReg(ReturnTupleSubReg, RegState::Define)
1165+
.addReg(FormTupleOpReg)
1166+
.addReg(FormTupleOpReg);
1167+
}
1168+
1169+
MI.eraseFromParent();
1170+
return true;
1171+
}
1172+
11451173
/// If MBBI references a pseudo instruction that should be expanded here,
11461174
/// do the expansion and return true. Otherwise return false.
11471175
bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
@@ -1724,6 +1752,10 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
17241752
return expandMultiVecPseudo(
17251753
MBB, MBBI, AArch64::ZPR4RegClass, AArch64::ZPR4StridedRegClass,
17261754
AArch64::LDNT1D_4Z, AArch64::LDNT1D_4Z_STRIDED);
1755+
case AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO:
1756+
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 2);
1757+
case AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO:
1758+
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 4);
17271759
}
17281760
return false;
17291761
}

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
504504

505505
bool SelectAllActivePredicate(SDValue N);
506506
bool SelectAnyPredicate(SDValue N);
507+
508+
void SelectFormTuplePseudo(SDNode *N, unsigned Size);
507509
};
508510

509511
class AArch64DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -7181,6 +7183,14 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
71817183
}
71827184
break;
71837185
}
7186+
case AArch64ISD::FORM_STRIDED_TUPLE_X2: {
7187+
SelectFormTuplePseudo(Node, 2);
7188+
return;
7189+
}
7190+
case AArch64ISD::FORM_STRIDED_TUPLE_X4: {
7191+
SelectFormTuplePseudo(Node, 4);
7192+
return;
7193+
}
71847194
}
71857195

71867196
// Select the default instruction
@@ -7438,3 +7448,20 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
74387448
Offset = CurDAG->getTargetConstant(0, SDLoc(N), MVT::i64);
74397449
return true;
74407450
}
7451+
7452+
void AArch64DAGToDAGISel::SelectFormTuplePseudo(SDNode *Node, unsigned Size) {
7453+
assert((Size == 2 || Size == 4) && "Invalid Tuple size");
7454+
EVT VT = Node->getValueType(0);
7455+
SmallVector<SDValue> Ops;
7456+
for (unsigned I = 0; I < Size; I++)
7457+
Ops.push_back(Node->getOperand(I));
7458+
SDLoc DL(Node);
7459+
unsigned Opc = Size == 2 ? AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO
7460+
: AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO;
7461+
SDNode *Tuple = CurDAG->getMachineNode(Opc, DL, MVT::Untyped, Ops);
7462+
SDValue SuperReg = SDValue(Tuple, 0);
7463+
for (unsigned I = 0; I < Size; ++I)
7464+
ReplaceUses(SDValue(Node, I), CurDAG->getTargetExtractSubreg(
7465+
AArch64::zsub0 + I, DL, VT, SuperReg));
7466+
CurDAG->RemoveDeadNode(Node);
7467+
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2808,6 +2808,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
28082808
MAKE_CASE(AArch64ISD::FMUL_PRED)
28092809
MAKE_CASE(AArch64ISD::FSUB_PRED)
28102810
MAKE_CASE(AArch64ISD::RDSVL)
2811+
MAKE_CASE(AArch64ISD::FORM_STRIDED_TUPLE_X2)
2812+
MAKE_CASE(AArch64ISD::FORM_STRIDED_TUPLE_X4)
28112813
MAKE_CASE(AArch64ISD::BIC)
28122814
MAKE_CASE(AArch64ISD::CBZ)
28132815
MAKE_CASE(AArch64ISD::CBNZ)
@@ -5709,6 +5711,46 @@ SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
57095711
Mask);
57105712
}
57115713

5714+
static unsigned getIntrinsicID(const SDNode *N);
5715+
5716+
SDValue TryLowerMultiVecSMEDotIntrinsic(SDValue Op, SelectionDAG &DAG,
5717+
unsigned Size) {
5718+
assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
5719+
auto IsStridedLoad = [Size](SDValue Op) -> bool {
5720+
unsigned Intrinsic = getIntrinsicID(Op.getNode());
5721+
if (Size == 2)
5722+
return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x2;
5723+
else
5724+
return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x4;
5725+
};
5726+
5727+
SmallVector<SDValue> Ops;
5728+
unsigned LastLoadIdx = Size == 2 ? 5 : 7;
5729+
unsigned LoadResNo = Op.getOperand(3).getResNo();
5730+
for (unsigned I = 3; I < LastLoadIdx; I++) {
5731+
if (!IsStridedLoad(Op->getOperand(I)) ||
5732+
Op.getOperand(I).getResNo() != LoadResNo)
5733+
return SDValue();
5734+
Ops.push_back(Op->getOperand(I));
5735+
}
5736+
5737+
EVT VT = Op->getOperand(3).getValueType();
5738+
SDVTList VTList =
5739+
Size == 2 ? DAG.getVTList(VT, VT) : DAG.getVTList(VT, VT, VT, VT);
5740+
unsigned Opc = Size == 2 ? AArch64ISD::FORM_STRIDED_TUPLE_X2
5741+
: AArch64ISD::FORM_STRIDED_TUPLE_X4;
5742+
SDLoc DL(Op);
5743+
SDValue Pseudo = DAG.getNode(Opc, DL, VTList, Ops);
5744+
5745+
SmallVector<SDValue> DotOps = {Op.getOperand(0), Op->getOperand(1),
5746+
Op->getOperand(2)};
5747+
for (unsigned I = 0; I < Size; I++)
5748+
DotOps.push_back(Pseudo.getValue(I));
5749+
DotOps.push_back(Op->getOperand(DotOps.size()));
5750+
DotOps.push_back(Op->getOperand(DotOps.size()));
5751+
return DAG.getNode(Op->getOpcode(), DL, MVT::Other, DotOps);
5752+
}
5753+
57125754
// Lower an SME LDR/STR ZA intrinsic
57135755
// Case 1: If the vector number (vecnum) is an immediate in range, it gets
57145756
// folded into the instruction
@@ -5898,6 +5940,22 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
58985940
Op->getOperand(0), // Chain
58995941
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
59005942
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
5943+
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x4:
5944+
case Intrinsic::aarch64_sme_suvdot_lane_za32_vg1x4:
5945+
case Intrinsic::aarch64_sme_usvdot_lane_za32_vg1x4:
5946+
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x4:
5947+
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x4:
5948+
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x4:
5949+
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x4:
5950+
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x4:
5951+
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 4);
5952+
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x2:
5953+
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x2:
5954+
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x2:
5955+
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x2:
5956+
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x2:
5957+
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x2:
5958+
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 2);
59015959
}
59025960
}
59035961

@@ -7639,6 +7697,11 @@ static unsigned getIntrinsicID(const SDNode *N) {
76397697
return IID;
76407698
return Intrinsic::not_intrinsic;
76417699
}
7700+
case ISD::INTRINSIC_W_CHAIN: {
7701+
unsigned IID = N->getConstantOperandVal(1);
7702+
if (IID < Intrinsic::num_intrinsics)
7703+
return IID;
7704+
}
76427705
}
76437706
}
76447707

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,9 @@ enum NodeType : unsigned {
478478
SME_ZA_LDR,
479479
SME_ZA_STR,
480480

481+
FORM_STRIDED_TUPLE_X2,
482+
FORM_STRIDED_TUPLE_X4,
483+
481484
// NEON Load/Store with post-increment base updates
482485
LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE,
483486
LD3post,

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
2828
def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
2929
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
3030
[SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
31+
32+
def SDT_FORM_STRIDED_TUPLE_X2 : SDTypeProfile<4, 4,
33+
[SDTCisVec<0>, SDTCisSameAs<0, 1>,
34+
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;
35+
36+
def SDT_FORM_STRIDED_TUPLE_X4 : SDTypeProfile<4, 4,
37+
[SDTCisVec<0>, SDTCisSameAs<0, 1>,
38+
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>,
39+
SDTCisSameAs<0, 4>, SDTCisSameAs<0, 5>,
40+
SDTCisSameAs<0, 6>, SDTCisSameAs<0, 7>]>;
41+
3142
def AArch64CoalescerBarrier
3243
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;
3344

llvm/lib/Target/AArch64/SMEInstrFormats.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0, 4>", []>;
3434

3535
def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;
3636

37+
def FORM_STRIDED_TUPLE_X2_PSEUDO :
38+
Pseudo<(outs ZPR2Mul2:$tup),
39+
(ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]>{
40+
let hasSideEffects = 0;
41+
}
42+
43+
def FORM_STRIDED_TUPLE_X4_PSEUDO :
44+
Pseudo<(outs ZPR4Mul4:$tup),
45+
(ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]>{
46+
let hasSideEffects = 0;
47+
}
48+
3749
def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
3850
def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
3951
[SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;

0 commit comments

Comments
 (0)