-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[AArch64][SME2] Improve register allocation of multi-vector SME intrinsics #116399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
fc0d224
6992da2
0d4c931
7188a2d
ebc97b7
43939b9
426253c
645e30b
d7ccfe1
0f0bc84
7f3e687
6cb5c5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7639,6 +7639,11 @@ static unsigned getIntrinsicID(const SDNode *N) { | |||||||||
| return IID; | ||||||||||
| return Intrinsic::not_intrinsic; | ||||||||||
| } | ||||||||||
| case ISD::INTRINSIC_W_CHAIN: { | ||||||||||
| unsigned IID = N->getConstantOperandVal(1); | ||||||||||
| if (IID < Intrinsic::num_intrinsics) | ||||||||||
| return IID; | ||||||||||
| } | ||||||||||
|
||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
@@ -8641,6 +8646,55 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) { | |||||||||
| return ZExtBool; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| bool shouldUseFormStridedPseudo(MachineInstr &MI) { | ||||||||||
| MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); | ||||||||||
| bool UseFormStrided = false; | ||||||||||
| unsigned NumOperands = | ||||||||||
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| MI.getOpcode() == AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO ? 2 : 4; | ||||||||||
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
|
|
||||||||||
| // The FORM_STRIDED_TUPLE pseudo should only be used if the input operands | ||||||||||
| // are copy nodes where the source register is in a StridedOrContiguous | ||||||||||
| // class. For example: | ||||||||||
| // %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO .. | ||||||||||
| // %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous | ||||||||||
| // %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous | ||||||||||
| // %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO .. | ||||||||||
| // %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous | ||||||||||
| // %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous | ||||||||||
| // %9:zpr2mul2 = FORM_STRIDED_TUPLE_X2_PSEUDO %5:zpr, %8:zpr | ||||||||||
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
|
|
||||||||||
| MCRegister SubReg = MCRegister::NoRegister; | ||||||||||
| for (unsigned I = 1; I < MI.getNumOperands(); ++I) { | ||||||||||
| MachineOperand &MO = MI.getOperand(I); | ||||||||||
kmclaughlin-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| assert(MO.isReg() && "Unexpected operand to FORM_STRIDED_TUPLE"); | ||||||||||
kmclaughlin-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
|
|
||||||||||
| MachineOperand *Def = MRI.getOneDef(MO.getReg()); | ||||||||||
| if (!Def || !Def->isReg() || !Def->getParent()->isCopy()) { | ||||||||||
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| UseFormStrided = false; | ||||||||||
| break; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| MachineOperand CpyOp = Def->getParent()->getOperand(1); | ||||||||||
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
kmclaughlin-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| MachineOperand *Ld = MRI.getOneDef(CpyOp.getReg()); | ||||||||||
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| unsigned OpSubReg = CpyOp.getSubReg(); | ||||||||||
| if (SubReg == MCRegister::NoRegister) | ||||||||||
| SubReg = OpSubReg; | ||||||||||
| if (!Ld || !Ld->isReg() || OpSubReg != SubReg) { | ||||||||||
| UseFormStrided = false; | ||||||||||
| break; | ||||||||||
| } | ||||||||||
|
||||||||||
|
|
||||||||||
| const TargetRegisterClass *RegClass = | ||||||||||
| NumOperands == 2 ? &AArch64::ZPR2StridedOrContiguousRegClass | ||||||||||
| : &AArch64::ZPR4StridedOrContiguousRegClass; | ||||||||||
|
|
||||||||||
| if (MRI.getRegClass(Ld->getReg()) == RegClass) | ||||||||||
| UseFormStrided = true; | ||||||||||
|
||||||||||
| if (MRI.getRegClass(Ld->getReg()) == RegClass) | |
| UseFormStrided = true; | |
| if (MRI.getRegClass(Ld->getReg()) != RegClass) | |
| return false; |
sdesmalen-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than `continue`, it's better to `break` here: if one of the operands is not a COPY, we don't need to process the remaining operands.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
| MIB.addImm(SubRegs[I - 1]); | |
| MIB.addImm(AArch64::zsub0 + (I-1)); |
Then you can remove SubRegs[].
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1107,6 +1107,69 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC, | |
| } | ||
| } | ||
|
|
||
| // FORM_STRIDED_TUPLE nodes are created to improve register allocation where | ||
| // a consecutive multi-vector tuple is constructed from the same indices of | ||
| // multiple strided loads. This may still result in unnecessary copies between | ||
| // the loads and the tuple. Here we try to return a hint to assign the | ||
| // contiguous ZPRMulReg starting at the same register as the first operand of | ||
| // the pseudo, which should be a subregister of the first strided load. | ||
| // | ||
| // For example, if the first strided load has been assigned $z16_z20_z24_z28 | ||
| // and the operands of the pseudo are each accessing subregister zsub2, we | ||
| // should look through through Order to find a contiguous register which | ||
| // begins with $z24 (i.e. $z24_z25_z26_z27). | ||
| // | ||
| bool AArch64RegisterInfo::getRegAllocationHints( | ||
| Register VirtReg, ArrayRef<MCPhysReg> Order, | ||
| SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF, | ||
| const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const { | ||
| const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>(); | ||
| const TargetRegisterInfo *TRI = STI.getRegisterInfo(); | ||
|
||
| const MachineRegisterInfo &MRI = MF.getRegInfo(); | ||
| bool DefaultHints = | ||
| TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF, VRM); | ||
|
||
|
|
||
| unsigned RegID = MRI.getRegClass(VirtReg)->getID(); | ||
| if (RegID != AArch64::ZPR2Mul2RegClassID && | ||
| RegID != AArch64::ZPR4Mul4RegClassID) | ||
| return DefaultHints; | ||
|
|
||
|
||
| for (MachineInstr &MI : MRI.def_instructions(VirtReg)) { | ||
| if (MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO && | ||
| MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO) | ||
| continue; | ||
|
||
|
|
||
| // Look up the physical register mapped to the first load of the pseudo. | ||
| Register FirstLoadVirtReg = MI.getOperand(1).getReg(); | ||
|
||
| if (!VRM->hasPhys(FirstLoadVirtReg)) | ||
| continue; | ||
|
|
||
| int64_t SubRegIdx = -1; | ||
| MCRegister FirstLoadPhysReg = VRM->getPhys(FirstLoadVirtReg); | ||
|
|
||
| // The subreg number is used to access the correct unit of the | ||
| // strided register found in the map above. | ||
| SubRegIdx = MI.getOperand(1).getSubReg() - AArch64::zsub0; | ||
| if (SubRegIdx < 0 || SubRegIdx > 3) | ||
| continue; | ||
|
||
|
|
||
| SmallVector<Register, 4> RegUnits; | ||
| for (MCRegUnit Unit : TRI->regunits(FirstLoadPhysReg)) | ||
| RegUnits.push_back(Unit); | ||
|
|
||
| // Find the contiguous ZPRMul register which starts with the | ||
| // same register unit as the strided register and add to Hints. | ||
| Register StartReg = RegUnits[SubRegIdx]; | ||
| for (unsigned I = 0; I < Order.size(); ++I) { | ||
| Register Reg = *TRI->regunits(Order[I]).begin(); | ||
| if (Reg == StartReg) | ||
| Hints.push_back(Order[I]); | ||
| } | ||
|
||
| } | ||
|
|
||
| return DefaultHints; | ||
| } | ||
|
|
||
| unsigned AArch64RegisterInfo::getLocalAddressRegister( | ||
| const MachineFunction &MF) const { | ||
| const auto &MFI = MF.getFrameInfo(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,20 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0, 4>", []>; | |
|
|
||
| def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>; | ||
|
|
||
// Pseudo instructions producing a ZPR2Mul2 / ZPR4Mul4 tuple ($tup) from two or
// four individual ZPR operands. Both are marked as having no side effects and
// as having a post-instruction-selection hook.
let hasSideEffects = 0, hasPostISelHook = 1 in {
  def FORM_STRIDED_TUPLE_X2_PSEUDO
      : Pseudo<(outs ZPR2Mul2:$tup),
               (ins ZPR:$zn0, ZPR:$zn1), []>,
        Sched<[]>;

  def FORM_STRIDED_TUPLE_X4_PSEUDO
      : Pseudo<(outs ZPR4Mul4:$tup),
               (ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>,
        Sched<[]>;
}
|
|
||
| def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>; | ||
| def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore, | ||
| [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>; | ||
|
|
@@ -172,14 +186,14 @@ class SME2_ZA_TwoOp_VG2_Multi_Index_Pat<string name, SDPatternOperator intrinsic | |
| Operand imm_ty, ComplexPattern tileslice> | ||
| : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm, (i32 imm_ty:$i)), | ||
| (!cast<Instruction>(name # _PSEUDO) $base, $offset, | ||
| (REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1), zpr_ty:$Zm, imm_ty:$i)>; | ||
| (FORM_STRIDED_TUPLE_X2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>; | ||
|
||
|
|
||
| class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt, | ||
| Operand imm_ty, ComplexPattern tileslice> | ||
| : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), | ||
| vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm, (i32 imm_ty:$i)), | ||
| (!cast<Instruction>(name # _PSEUDO) $base, $offset, | ||
| (REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3), | ||
| (FORM_STRIDED_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4), | ||
| zpr_ty:$Zm, imm_ty:$i)>; | ||
|
|
||
| class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: