Skip to content

Commit 7188a2d

Browse files
- Removed the FORM_STRIDED_TUPLE_X# nodes, leaving only the pseudo nodes.
- Changed the tablegen patterns used by the dot intrinsics to always output the FORM_STRIDED_TUPLE_X#_PSEUDO nodes.
- Check that the operands to the pseudo are copies from a StridedOrContiguous register class in AdjustInstrPostInstrSelection, falling back on creating a REG_SEQUENCE node if not.
1 parent 0d4c931 commit 7188a2d

File tree

5 files changed

+93
-85
lines changed

5 files changed

+93
-85
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,8 @@ bool AArch64ExpandPseudo::expandFormTuplePseudo(
11591159
Register FormTupleOpReg = MI.getOperand(i + 1).getReg();
11601160
Register ReturnTupleSubReg =
11611161
TRI->getSubReg(ReturnTuple, AArch64::zsub0 + i);
1162+
// Add copies to ensure the subregisters remain in the correct order
1163+
// for any contiguous operation they are used by.
11621164
if (FormTupleOpReg != ReturnTupleSubReg)
11631165
BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORR_ZZZ))
11641166
.addReg(ReturnTupleSubReg, RegState::Define)

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
505505
bool SelectAllActivePredicate(SDValue N);
506506
bool SelectAnyPredicate(SDValue N);
507507

508-
void SelectFormTuplePseudo(SDNode *N, unsigned Size);
508+
void SelectFormTuplePseudo(SDNode *N);
509509
};
510510

511511
class AArch64DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -7183,14 +7183,6 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
71837183
}
71847184
break;
71857185
}
7186-
case AArch64ISD::FORM_STRIDED_TUPLE_X2: {
7187-
SelectFormTuplePseudo(Node, 2);
7188-
return;
7189-
}
7190-
case AArch64ISD::FORM_STRIDED_TUPLE_X4: {
7191-
SelectFormTuplePseudo(Node, 4);
7192-
return;
7193-
}
71947186
}
71957187

71967188
// Select the default instruction
@@ -7448,20 +7440,3 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
74487440
Offset = CurDAG->getTargetConstant(0, SDLoc(N), MVT::i64);
74497441
return true;
74507442
}
7451-
7452-
void AArch64DAGToDAGISel::SelectFormTuplePseudo(SDNode *Node, unsigned Size) {
7453-
assert((Size == 2 || Size == 4) && "Invalid Tuple size");
7454-
EVT VT = Node->getValueType(0);
7455-
SmallVector<SDValue> Ops;
7456-
for (unsigned I = 0; I < Size; I++)
7457-
Ops.push_back(Node->getOperand(I));
7458-
SDLoc DL(Node);
7459-
unsigned Opc = Size == 2 ? AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO
7460-
: AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO;
7461-
SDNode *Tuple = CurDAG->getMachineNode(Opc, DL, MVT::Untyped, Ops);
7462-
SDValue SuperReg = SDValue(Tuple, 0);
7463-
for (unsigned I = 0; I < Size; ++I)
7464-
ReplaceUses(SDValue(Node, I), CurDAG->getTargetExtractSubreg(
7465-
AArch64::zsub0 + I, DL, VT, SuperReg));
7466-
CurDAG->RemoveDeadNode(Node);
7467-
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 71 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5711,46 +5711,6 @@ SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
57115711
Mask);
57125712
}
57135713

5714-
static unsigned getIntrinsicID(const SDNode *N);
5715-
5716-
SDValue TryLowerMultiVecSMEDotIntrinsic(SDValue Op, SelectionDAG &DAG,
5717-
unsigned Size) {
5718-
assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
5719-
auto IsStridedLoad = [Size](SDValue Op) -> bool {
5720-
unsigned Intrinsic = getIntrinsicID(Op.getNode());
5721-
if (Size == 2)
5722-
return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x2;
5723-
else
5724-
return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x4;
5725-
};
5726-
5727-
SmallVector<SDValue> Ops;
5728-
unsigned LastLoadIdx = Size == 2 ? 5 : 7;
5729-
unsigned LoadResNo = Op.getOperand(3).getResNo();
5730-
for (unsigned I = 3; I < LastLoadIdx; I++) {
5731-
if (!IsStridedLoad(Op->getOperand(I)) ||
5732-
Op.getOperand(I).getResNo() != LoadResNo)
5733-
return SDValue();
5734-
Ops.push_back(Op->getOperand(I));
5735-
}
5736-
5737-
EVT VT = Op->getOperand(3).getValueType();
5738-
SDVTList VTList =
5739-
Size == 2 ? DAG.getVTList(VT, VT) : DAG.getVTList(VT, VT, VT, VT);
5740-
unsigned Opc = Size == 2 ? AArch64ISD::FORM_STRIDED_TUPLE_X2
5741-
: AArch64ISD::FORM_STRIDED_TUPLE_X4;
5742-
SDLoc DL(Op);
5743-
SDValue Pseudo = DAG.getNode(Opc, DL, VTList, Ops);
5744-
5745-
SmallVector<SDValue> DotOps = {Op.getOperand(0), Op->getOperand(1),
5746-
Op->getOperand(2)};
5747-
for (unsigned I = 0; I < Size; I++)
5748-
DotOps.push_back(Pseudo.getValue(I));
5749-
DotOps.push_back(Op->getOperand(DotOps.size()));
5750-
DotOps.push_back(Op->getOperand(DotOps.size()));
5751-
return DAG.getNode(Op->getOpcode(), DL, MVT::Other, DotOps);
5752-
}
5753-
57545714
// Lower an SME LDR/STR ZA intrinsic
57555715
// Case 1: If the vector number (vecnum) is an immediate in range, it gets
57565716
// folded into the instruction
@@ -5940,22 +5900,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
59405900
Op->getOperand(0), // Chain
59415901
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
59425902
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
5943-
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x4:
5944-
case Intrinsic::aarch64_sme_suvdot_lane_za32_vg1x4:
5945-
case Intrinsic::aarch64_sme_usvdot_lane_za32_vg1x4:
5946-
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x4:
5947-
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x4:
5948-
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x4:
5949-
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x4:
5950-
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x4:
5951-
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 4);
5952-
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x2:
5953-
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x2:
5954-
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x2:
5955-
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x2:
5956-
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x2:
5957-
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x2:
5958-
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 2);
59595903
}
59605904
}
59615905

@@ -8729,6 +8673,77 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
87298673
}
87308674
}
87318675

8676+
if (MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ||
8677+
MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO) {
8678+
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
8679+
bool UseFormStrided = false;
8680+
unsigned Size =
8681+
MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ? 2 : 4;
8682+
8683+
// The FORM_STRIDED_TUPLE pseudo should only be used if the input operands
8684+
// are copy nodes where the source register is in a StridedOrContiguous
8685+
// class. For example:
8686+
// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
8687+
// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
8688+
// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
8689+
// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
8690+
// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
8691+
// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
8692+
// %9:zpr2mul2 = FORM_STRIDED_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
8693+
8694+
SmallVector<unsigned, 4> OpSubRegs;
8695+
for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
8696+
MachineOperand &MO = MI.getOperand(I);
8697+
if (!MO.isReg())
8698+
continue;
8699+
8700+
MachineOperand *Def = MRI.getOneDef(MO.getReg());
8701+
if (!Def || !Def->isReg() || !Def->getParent()->isCopy())
8702+
continue;
8703+
8704+
MachineInstr *Cpy = Def->getParent();
8705+
MachineOperand CpyOp = Cpy->getOperand(1);
8706+
if (!CpyOp.isReg())
8707+
continue;
8708+
8709+
MachineOperand *Ld = MRI.getOneDef(CpyOp.getReg());
8710+
OpSubRegs.push_back(CpyOp.getSubReg());
8711+
if (!Ld || !Ld->isReg())
8712+
continue;
8713+
8714+
const TargetRegisterClass *RegClass =
8715+
Size == 2 ? &AArch64::ZPR2StridedOrContiguousRegClass
8716+
: &AArch64::ZPR4StridedOrContiguousRegClass;
8717+
8718+
if (MRI.getRegClass(Ld->getReg()) == RegClass)
8719+
UseFormStrided = true;
8720+
}
8721+
8722+
// Ensure the operands all use the same subreg index.
8723+
if (!std::equal(OpSubRegs.begin(), OpSubRegs.end(), OpSubRegs.begin()))
8724+
UseFormStrided = false;
8725+
8726+
// If input values to the FORM_STRIDED_TUPLE pseudo aren't copies from a
8727+
// StridedOrContiguous class, fall back on REG_SEQUENCE node.
8728+
if (!UseFormStrided) {
8729+
static const unsigned SubRegs[] = {AArch64::zsub0, AArch64::zsub1,
8730+
AArch64::zsub2, AArch64::zsub3};
8731+
8732+
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
8733+
MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
8734+
TII->get(TargetOpcode::REG_SEQUENCE),
8735+
MI.getOperand(0).getReg());
8736+
8737+
for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
8738+
MIB.add(MI.getOperand(I));
8739+
MIB.addImm(SubRegs[I - 1]);
8740+
}
8741+
8742+
MI.eraseFromParent();
8743+
}
8744+
return;
8745+
}
8746+
87328747
// Add an implicit use of 'VG' for ADDXri/SUBXri, which are instructions that
87338748
// have nothing to do with VG, were it not that they are used to materialise a
87348749
// frame-address. If they contain a frame-index to a scalable vector, this

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
2929
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
3030
[SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
3131

32-
def SDT_FORM_STRIDED_TUPLE_X2 : SDTypeProfile<4, 4,
32+
def SDT_FORM_STRIDED_TUPLE_X2 : SDTypeProfile<2, 2,
3333
[SDTCisVec<0>, SDTCisSameAs<0, 1>,
3434
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;
3535

llvm/lib/Target/AArch64/SMEInstrFormats.td

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ def FORM_STRIDED_TUPLE_X2_PSEUDO :
3838
Pseudo<(outs ZPR2Mul2:$tup),
3939
(ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]>{
4040
let hasSideEffects = 0;
41+
let hasPostISelHook = 1;
4142
}
4243

4344
def FORM_STRIDED_TUPLE_X4_PSEUDO :
4445
Pseudo<(outs ZPR4Mul4:$tup),
4546
(ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]>{
4647
let hasSideEffects = 0;
48+
let hasPostISelHook = 1;
4749
}
4850

4951
def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
@@ -186,6 +188,12 @@ class SME2_ZA_TwoOp_VG2_Multi_Index_Pat<string name, SDPatternOperator intrinsic
186188
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
187189
(REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1), zpr_ty:$Zm, imm_ty:$i)>;
188190

191+
class SME2_ZA_TwoOp_VG2_Multi_Index_FormStrided_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
192+
Operand imm_ty, ComplexPattern tileslice>
193+
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm, (i32 imm_ty:$i)),
194+
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
195+
(FORM_STRIDED_TUPLE_X2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>;
196+
189197
class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
190198
Operand imm_ty, ComplexPattern tileslice>
191199
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
@@ -194,6 +202,14 @@ class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic
194202
(REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3),
195203
zpr_ty:$Zm, imm_ty:$i)>;
196204

205+
class SME2_ZA_TwoOp_VG4_Multi_Index_FormStrided_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
206+
Operand imm_ty, ComplexPattern tileslice>
207+
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
208+
vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm, (i32 imm_ty:$i)),
209+
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
210+
(FORM_STRIDED_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
211+
zpr_ty:$Zm, imm_ty:$i)>;
212+
197213
class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
198214
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, (i32 imm_ty:$i))),
199215
(!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1), imm_ty:$i)>;
@@ -2635,7 +2651,7 @@ multiclass sme2_multi_vec_array_vg2_index_32b<string mnemonic, bits<2> sz, bits<
26352651
}
26362652
def _PSEUDO : sme2_za_array_2op_multi_index_pseudo<NAME, sme_elm_idx0_7, multi_vector_ty, vector_ty, VectorIndexS32b_timm, SMEMatrixArray>;
26372653

2638-
def : SME2_ZA_TwoOp_VG2_Multi_Index_Pat<NAME, intrinsic, sme_elm_idx0_7, vector_ty, vt, VectorIndexS32b_timm, tileslice16>;
2654+
def : SME2_ZA_TwoOp_VG2_Multi_Index_FormStrided_Pat<NAME, intrinsic, sme_elm_idx0_7, vector_ty, vt, VectorIndexS32b_timm, tileslice16>;
26392655

26402656
def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm3], $Zn, $Zm$i",
26412657
(!cast<Instruction>(NAME) MatrixOp32:$ZAda, MatrixIndexGPR32Op8_11:$Rv, sme_elm_idx0_7:$imm3,
@@ -2778,7 +2794,7 @@ multiclass sme2_multi_vec_array_vg4_index_32b<string mnemonic, bits<4> op,
27782794

27792795
def _PSEUDO : sme2_za_array_2op_multi_index_pseudo<NAME, sme_elm_idx0_7, multi_vector_ty, vector_ty, VectorIndexS32b_timm, SMEMatrixArray>;
27802796

2781-
def : SME2_ZA_TwoOp_VG4_Multi_Index_Pat<NAME, intrinsic, sme_elm_idx0_7, vector_ty, vt, VectorIndexS32b_timm, tileslice16>;
2797+
def : SME2_ZA_TwoOp_VG4_Multi_Index_FormStrided_Pat<NAME, intrinsic, sme_elm_idx0_7, vector_ty, vt, VectorIndexS32b_timm, tileslice16>;
27822798

27832799
def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm3], $Zn, $Zm$i",
27842800
(!cast<Instruction>(NAME) MatrixOp32:$ZAda, MatrixIndexGPR32Op8_11:$Rv, sme_elm_idx0_7:$imm3,

0 commit comments

Comments
 (0)