
Commit ebc97b7

[AArch64][SME2] Add register allocation hints for ZPRMulReg
This patch implements getRegAllocationHints to improve register allocation for the ZPR2Mul2Reg & ZPR4Mul4Reg classes. If a FORM_STRIDED_TUPLE pseudo defines the virtual register, getRegAllocationHints tries to hint a contiguous ZPRMulReg that begins with the physical register backing the first operand of the pseudo, i.e. a subregister of the first strided load. For example, if the first strided load has been assigned $z16_z20_z24_z28 and the operands of the pseudo each access subregister zsub2, the correct register to hint would be $z24_z25_z26_z27.
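
To illustrate the mapping the hint computes, here is a minimal, self-contained sketch (not part of the patch) that models the strided tuple as a plain array of register numbers; the helper name and the string-based modelling are assumptions for illustration only.

// Illustration only: models the example from the commit message without any
// LLVM types. A strided 4x load such as $z16_z20_z24_z28 contributes one
// register per subregister slot; the hint is the contiguous tuple starting at
// the slot selected by the pseudo's subregister index.
#include <array>
#include <cassert>
#include <string>

// Hypothetical helper: given the strided tuple's registers and a subregister
// index (zsub0..zsub3 -> 0..3), name the contiguous tuple to hint.
static std::string contiguousHint(const std::array<unsigned, 4> &StridedRegs,
                                  unsigned SubRegIdx) {
  unsigned Start = StridedRegs[SubRegIdx];
  // A contiguous 4x tuple covers four consecutive Z registers from Start.
  return "$z" + std::to_string(Start) + "_z" + std::to_string(Start + 1) +
         "_z" + std::to_string(Start + 2) + "_z" + std::to_string(Start + 3);
}

int main() {
  // First strided load assigned $z16_z20_z24_z28, pseudo operands use zsub2.
  std::array<unsigned, 4> Strided = {16, 20, 24, 28};
  assert(contiguousHint(Strided, 2) == "$z24_z25_z26_z27");
  return 0;
}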
Parent: 7188a2d

4 files changed (+124, -196 lines)

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp

Lines changed: 75 additions & 0 deletions
@@ -1107,6 +1107,81 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
   }
 }
 
+// FORM_STRIDED_TUPLE nodes are created to improve register allocation where
+// a consecutive multi-vector tuple is constructed from the same indices of
+// multiple strided loads. This may still result in unnecessary copies between
+// the loads and the tuple. Here we try to return a hint to assign the
+// contiguous ZPRMulReg starting at the same register as the first operand of
+// the pseudo, which should be a subregister of the first strided load.
+//
+// For example, if the first strided load has been assigned $z16_z20_z24_z28
+// and the operands of the pseudo are each accessing subregister zsub2, we
+// should look through Order to find a contiguous register which
+// begins with $z24 (i.e. $z24_z25_z26_z27).
+//
+bool AArch64RegisterInfo::getRegAllocationHints(
+    Register VirtReg, ArrayRef<MCPhysReg> Order,
+    SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
+    const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
+  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+  const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+  const MachineRegisterInfo &MRI = MF.getRegInfo();
+  bool DefaultHints =
+      TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF, VRM);
+
+  unsigned RegID = MRI.getRegClass(VirtReg)->getID();
+  if (RegID != AArch64::ZPR2Mul2RegClassID &&
+      RegID != AArch64::ZPR4Mul4RegClassID)
+    return DefaultHints;
+
+  for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
+    if (MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO &&
+        MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO)
+      continue;
+
+    // Look up the physical register mapped to the first load of the pseudo.
+    Register FirstLoadVirtReg = MI.getOperand(1).getReg();
+    if (!VRM->hasPhys(FirstLoadVirtReg))
+      continue;
+
+    unsigned SubRegIdx = 0;
+    MCRegister FirstLoadPhysReg = VRM->getPhys(FirstLoadVirtReg);
+
+    // The subreg number is used to access the correct unit of the
+    // strided register found in the map above.
+    switch (MI.getOperand(1).getSubReg()) {
+    case AArch64::zsub0:
+      break;
+    case AArch64::zsub1:
+      SubRegIdx = 1;
+      break;
+    case AArch64::zsub2:
+      SubRegIdx = 2;
+      break;
+    case AArch64::zsub3:
+      SubRegIdx = 3;
+      break;
+    default:
+      continue;
+    }
+
+    SmallVector<Register, 4> RegUnits;
+    for (MCRegUnit Unit : TRI->regunits(FirstLoadPhysReg))
+      RegUnits.push_back(Unit);
+
+    // Find the contiguous ZPRMul register which starts with the
+    // same register unit as the strided register and add to Hints.
+    Register StartReg = RegUnits[SubRegIdx];
+    for (unsigned I = 0; I < Order.size(); ++I) {
+      Register Reg = *TRI->regunits(Order[I]).begin();
+      if (Reg == StartReg)
+        Hints.push_back(Order[I]);
+    }
+  }
+
+  return DefaultHints;
+}
+
 unsigned AArch64RegisterInfo::getLocalAddressRegister(
     const MachineFunction &MF) const {
   const auto &MFI = MF.getFrameInfo();

llvm/lib/Target/AArch64/AArch64RegisterInfo.h

Lines changed: 5 additions & 0 deletions
@@ -134,6 +134,11 @@ class AArch64RegisterInfo final : public AArch64GenRegisterInfo {
   unsigned getRegPressureLimit(const TargetRegisterClass *RC,
                                MachineFunction &MF) const override;
 
+  bool getRegAllocationHints(Register VirtReg, ArrayRef<MCPhysReg> Order,
+                             SmallVectorImpl<MCPhysReg> &Hints,
+                             const MachineFunction &MF, const VirtRegMap *VRM,
+                             const LiveRegMatrix *Matrix) const override;
+
   unsigned getLocalAddressRegister(const MachineFunction &MF) const;
   bool regNeedsCFI(unsigned Reg, unsigned &RegToUseForCFI) const;

llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll

Lines changed: 24 additions & 104 deletions
@@ -590,12 +590,8 @@ define void @udot_form_2x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: mov w8, wzr
 ; CHECK-NEXT: ld1b { z16.b, z24.b }, pn8/z, [x0]
 ; CHECK-NEXT: ld1b { z17.b, z25.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z16.b, z17.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z24.b, z25.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -622,26 +618,10 @@ define void @udot_form_4x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: add x10, x9, x1
 ; CHECK-NEXT: ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
 ; CHECK-NEXT: ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: mov z2.d, z18.d
-; CHECK-NEXT: mov z3.d, z19.d
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z20.d
-; CHECK-NEXT: mov z1.d, z21.d
-; CHECK-NEXT: mov z2.d, z22.d
-; CHECK-NEXT: mov z3.d, z23.d
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: mov z2.d, z26.d
-; CHECK-NEXT: mov z3.d, z27.d
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z28.d
-; CHECK-NEXT: mov z1.d, z29.d
-; CHECK-NEXT: mov z2.d, z30.d
-; CHECK-NEXT: mov z3.d, z31.d
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -752,12 +732,8 @@ define void @usdot_form_2x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: mov w8, wzr
 ; CHECK-NEXT: ld1b { z16.b, z24.b }, pn8/z, [x0]
 ; CHECK-NEXT: ld1b { z17.b, z25.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z16.b, z17.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z24.b, z25.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -784,26 +760,10 @@ define void @usdot_form_4x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: add x10, x9, x1
 ; CHECK-NEXT: ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
 ; CHECK-NEXT: ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: mov z2.d, z18.d
-; CHECK-NEXT: mov z3.d, z19.d
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z20.d
-; CHECK-NEXT: mov z1.d, z21.d
-; CHECK-NEXT: mov z2.d, z22.d
-; CHECK-NEXT: mov z3.d, z23.d
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: mov z2.d, z26.d
-; CHECK-NEXT: mov z3.d, z27.d
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z28.d
-; CHECK-NEXT: mov z1.d, z29.d
-; CHECK-NEXT: mov z2.d, z30.d
-; CHECK-NEXT: mov z3.d, z31.d
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -916,12 +876,8 @@ define void @sdot_form_2x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: mov w8, wzr
 ; CHECK-NEXT: ld1b { z16.b, z24.b }, pn8/z, [x0]
 ; CHECK-NEXT: ld1b { z17.b, z25.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z16.b, z17.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z24.b, z25.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -948,26 +904,10 @@ define void @sdot_form_4x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: add x10, x9, x1
 ; CHECK-NEXT: ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
 ; CHECK-NEXT: ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: mov z2.d, z18.d
-; CHECK-NEXT: mov z3.d, z19.d
-; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z20.d
-; CHECK-NEXT: mov z1.d, z21.d
-; CHECK-NEXT: mov z2.d, z22.d
-; CHECK-NEXT: mov z3.d, z23.d
-; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: mov z2.d, z26.d
-; CHECK-NEXT: mov z3.d, z27.d
-; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z28.d
-; CHECK-NEXT: mov z1.d, z29.d
-; CHECK-NEXT: mov z2.d, z30.d
-; CHECK-NEXT: mov z3.d, z31.d
-; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -1080,12 +1020,8 @@ define void @sudot_form_2x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: mov w8, wzr
 ; CHECK-NEXT: ld1b { z16.b, z24.b }, pn8/z, [x0]
 ; CHECK-NEXT: ld1b { z17.b, z25.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: sudot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: sudot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
+; CHECK-NEXT: sudot za.s[w8, 0, vgx2], { z16.b, z17.b }, z0.b[0]
+; CHECK-NEXT: sudot za.s[w8, 0, vgx2], { z24.b, z25.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
@@ -1112,26 +1048,10 @@ define void @sudot_form_4x_tuple(ptr %ptr, i64 %stride) #0 {
 ; CHECK-NEXT: add x10, x9, x1
 ; CHECK-NEXT: ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
 ; CHECK-NEXT: ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
-; CHECK-NEXT: mov z0.d, z16.d
-; CHECK-NEXT: mov z1.d, z17.d
-; CHECK-NEXT: mov z2.d, z18.d
-; CHECK-NEXT: mov z3.d, z19.d
-; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z20.d
-; CHECK-NEXT: mov z1.d, z21.d
-; CHECK-NEXT: mov z2.d, z22.d
-; CHECK-NEXT: mov z3.d, z23.d
-; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z24.d
-; CHECK-NEXT: mov z1.d, z25.d
-; CHECK-NEXT: mov z2.d, z26.d
-; CHECK-NEXT: mov z3.d, z27.d
-; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: mov z0.d, z28.d
-; CHECK-NEXT: mov z1.d, z29.d
-; CHECK-NEXT: mov z2.d, z30.d
-; CHECK-NEXT: mov z3.d, z31.d
-; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
+; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT: sudot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
 ; CHECK-NEXT: ret
 entry:
 %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
