Skip to content

Commit 1317083

Browse files
authored
[AArch64][SME] Support saving/restoring ZT0 in the MachineSMEABIPass (#166362)
This patch extends the MachineSMEABIPass to support ZT0. This is done with the addition of two new states: - `ACTIVE_ZT0_SAVED` * This is used when calling a function that shares ZA, but does not share ZT0 (i.e., no ZT0 attributes) * This state indicates ZT0 must be saved to the save slot, but ZA must remain on, with no lazy save setup - `LOCAL_COMMITTED` * This is used for saving ZT0 in functions without ZA state * This state indicates ZA is off and ZT0 has been saved * This state is general enough to support ZA, but the required transitions have not been implemented† To aid with readability, the state transitions have been reworked to a switch of `transitionFrom(<FromState>).to(<ToState>)`, rather than nested ifs, which helps manage more transitions. † This could be implemented to handle some cases of undefined behavior better.
1 parent dda15ad commit 1317083

File tree

7 files changed

+480
-108
lines changed

7 files changed

+480
-108
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
17171717
}
17181718
case AArch64::InOutZAUsePseudo:
17191719
case AArch64::RequiresZASavePseudo:
1720+
case AArch64::RequiresZT0SavePseudo:
17201721
case AArch64::SMEStateAllocPseudo:
17211722
case AArch64::COALESCER_BARRIER_FPR16:
17221723
case AArch64::COALESCER_BARRIER_FPR32:

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9642,6 +9642,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96429642
if (CallAttrs.requiresLazySave() ||
96439643
CallAttrs.requiresPreservingAllZAState())
96449644
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
9645+
else if (CallAttrs.requiresPreservingZT0())
9646+
ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE;
96459647
else if (CallAttrs.caller().hasZAState() ||
96469648
CallAttrs.caller().hasZT0State())
96479649
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
@@ -9761,7 +9763,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97619763

97629764
SDValue ZTFrameIdx;
97639765
MachineFrameInfo &MFI = MF.getFrameInfo();
9764-
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
9766+
bool ShouldPreserveZT0 =
9767+
!UseNewSMEABILowering && CallAttrs.requiresPreservingZT0();
97659768

97669769
// If the caller has ZT0 state which will not be preserved by the callee,
97679770
// spill ZT0 before the call.
@@ -9774,7 +9777,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97749777

97759778
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
97769779
// PSTATE.ZA before the call if there is no lazy-save active.
9777-
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
9780+
bool DisableZA =
9781+
!UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall();
97789782
assert((!DisableZA || !RequiresLazySave) &&
97799783
"Lazy-save should have PSTATE.SM=1 on entry to the function");
97809784

@@ -10263,7 +10267,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
1026310267
getSMToggleCondition(CallAttrs));
1026410268
}
1026510269

10266-
if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
10270+
if (!UseNewSMEABILowering &&
10271+
(RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()))
1026710272
// Unconditionally resume ZA.
1026810273
Result = DAG.getNode(
1026910274
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
102102
let hasSideEffects = 1, isMeta = 1 in {
103103
def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
104104
def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
105+
def RequiresZT0SavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
105106
}
106107

107108
def SMEStateAllocPseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
@@ -122,6 +123,11 @@ def AArch64_requires_za_save
122123
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
123124
def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
124125

126+
def AArch64_requires_zt0_save
127+
: SDNode<"AArch64ISD::REQUIRES_ZT0_SAVE", SDTypeProfile<0, 0, []>,
128+
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
129+
def : Pat<(AArch64_requires_zt0_save), (RequiresZT0SavePseudo)>;
130+
125131
def AArch64_sme_state_alloc
126132
: SDNode<"AArch64ISD::SME_STATE_ALLOC", SDTypeProfile<0, 0,[]>,
127133
[SDNPHasChain]>;

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 160 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,34 @@ using namespace llvm;
7272

7373
namespace {
7474

75-
enum ZAState {
75+
// Note: For agnostic ZA, we assume the function is always entered/exited in the
76+
// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
77+
// possibility, but for the purpose of placing ZA saves/restores, that does not
78+
// matter).
79+
enum ZAState : uint8_t {
7680
// Any/unknown state (not valid)
7781
ANY = 0,
7882

7983
// ZA is in use and active (i.e. within the accumulator)
8084
ACTIVE,
8185

86+
// ZA is active, but ZT0 has been saved.
87+
// This handles the edge case of sharedZA && !sharesZT0.
88+
ACTIVE_ZT0_SAVED,
89+
8290
// A ZA save has been set up or committed (i.e. ZA is dormant or off)
91+
// If the function uses ZT0 it must also be saved.
8392
LOCAL_SAVED,
8493

94+
// ZA has been committed to the lazy save buffer of the current function.
95+
// If the function uses ZT0 it must also be saved.
96+
// ZA is off.
97+
LOCAL_COMMITTED,
98+
8599
// The ZA/ZT0 state on entry to the function.
86100
ENTRY,
87101

88-
// ZA is off
102+
// ZA is off.
89103
OFF,
90104

91105
// The number of ZA states (not a valid state)
@@ -164,6 +178,14 @@ class EmitContext {
164178
return AgnosticZABufferPtr;
165179
}
166180

181+
int getZT0SaveSlot(MachineFunction &MF) {
182+
if (ZT0SaveFI)
183+
return *ZT0SaveFI;
184+
MachineFrameInfo &MFI = MF.getFrameInfo();
185+
ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
186+
return *ZT0SaveFI;
187+
}
188+
167189
/// Returns true if the function must allocate a ZA save buffer on entry. This
168190
/// will be the case if, at any point in the function, a ZA save was emitted.
169191
bool needsSaveBuffer() const {
@@ -173,6 +195,7 @@ class EmitContext {
173195
}
174196

175197
private:
198+
std::optional<int> ZT0SaveFI;
176199
std::optional<int> TPIDR2BlockFI;
177200
Register AgnosticZABufferPtr = AArch64::NoRegister;
178201
};
@@ -184,8 +207,10 @@ class EmitContext {
184207
/// state would not be legal, as transitioning to it drops the content of ZA.
185208
static bool isLegalEdgeBundleZAState(ZAState State) {
186209
switch (State) {
187-
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
188-
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
210+
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
211+
case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
212+
case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack.
213+
case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack.
189214
return true;
190215
default:
191216
return false;
@@ -199,7 +224,9 @@ StringRef getZAStateString(ZAState State) {
199224
switch (State) {
200225
MAKE_CASE(ZAState::ANY)
201226
MAKE_CASE(ZAState::ACTIVE)
227+
MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
202228
MAKE_CASE(ZAState::LOCAL_SAVED)
229+
MAKE_CASE(ZAState::LOCAL_COMMITTED)
203230
MAKE_CASE(ZAState::ENTRY)
204231
MAKE_CASE(ZAState::OFF)
205232
default:
@@ -221,18 +248,39 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
221248
/// Returns the required ZA state needed before \p MI and an iterator pointing
222249
/// to where any code required to change the ZA state should be inserted.
223250
static std::pair<ZAState, MachineBasicBlock::iterator>
224-
getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
225-
bool ZAOffAtReturn) {
251+
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
252+
SMEAttrs SMEFnAttrs) {
226253
MachineBasicBlock::iterator InsertPt(MI);
227254

255+
// Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
256+
// intended to mark the position immediately before a call. Due to
257+
// SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
258+
// so we use std::prev(InsertPt) to get the position before the call.
259+
228260
if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
229261
return {ZAState::ACTIVE, std::prev(InsertPt)};
230262

263+
// Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
231264
if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
232265
return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
233266

234-
if (MI.isReturn())
267+
// If we only need to save ZT0 there's two cases to consider:
268+
// 1. The function has ZA state (that we don't need to save).
269+
// - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
270+
// This only saves ZT0.
271+
// 2. The function does not have ZA state
272+
// - In this case we switch to "LOCAL_COMMITTED" state.
273+
// This saves ZT0 and turns ZA off.
274+
if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
275+
return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
276+
: ZAState::LOCAL_COMMITTED,
277+
std::prev(InsertPt)};
278+
}
279+
280+
if (MI.isReturn()) {
281+
bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
235282
return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
283+
}
236284

237285
for (auto &MO : MI.operands()) {
238286
if (isZAorZTRegOp(TRI, MO))
@@ -280,6 +328,9 @@ struct MachineSMEABI : public MachineFunctionPass {
280328
/// predecessors).
281329
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
282330

331+
void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
332+
MachineBasicBlock::iterator MBBI, bool IsSave);
333+
283334
// Emission routines for private and shared ZA functions (using lazy saves).
284335
void emitSMEPrologue(MachineBasicBlock &MBB,
285336
MachineBasicBlock::iterator MBBI);
@@ -290,8 +341,8 @@ struct MachineSMEABI : public MachineFunctionPass {
290341
MachineBasicBlock::iterator MBBI);
291342
void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
292343
MachineBasicBlock::iterator MBBI);
293-
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
294-
bool ClearTPIDR2);
344+
void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
345+
bool ClearTPIDR2, bool On);
295346

296347
// Emission routines for agnostic ZA functions.
297348
void emitSetupFullZASave(MachineBasicBlock &MBB,
@@ -409,7 +460,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
409460
Block.FixedEntryState = ZAState::ENTRY;
410461
} else if (MBB.isEHPad()) {
411462
// EH entry block:
412-
Block.FixedEntryState = ZAState::LOCAL_SAVED;
463+
Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
413464
}
414465

415466
LiveRegUnits LiveUnits(*TRI);
@@ -431,8 +482,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
431482
PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
432483
}
433484
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
434-
auto [NeededState, InsertPt] = getZAStateBeforeInst(
435-
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
485+
auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
436486
assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
437487
"Unexpected state change insertion point!");
438488
// TODO: Do something to avoid state changes where NZCV is live.
@@ -752,9 +802,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
752802
restorePhyRegSave(RegSave, MBB, MBBI, DL);
753803
}
754804

755-
void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
756-
MachineBasicBlock::iterator MBBI,
757-
bool ClearTPIDR2) {
805+
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
806+
MachineBasicBlock::iterator MBBI,
807+
bool ClearTPIDR2, bool On) {
758808
DebugLoc DL = getDebugLoc(MBB, MBBI);
759809

760810
if (ClearTPIDR2)
@@ -765,7 +815,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
765815
// Disable ZA.
766816
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
767817
.addImm(AArch64SVCR::SVCRZA)
768-
.addImm(0);
818+
.addImm(On ? 1 : 0);
769819
}
770820

771821
void MachineSMEABI::emitAllocateLazySaveBuffer(
@@ -891,6 +941,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
891941
restorePhyRegSave(RegSave, MBB, MBBI, DL);
892942
}
893943

944+
void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
945+
MachineBasicBlock &MBB,
946+
MachineBasicBlock::iterator MBBI,
947+
bool IsSave) {
948+
DebugLoc DL = getDebugLoc(MBB, MBBI);
949+
Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
950+
951+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
952+
.addFrameIndex(Context.getZT0SaveSlot(*MF))
953+
.addImm(0)
954+
.addImm(0);
955+
956+
if (IsSave) {
957+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
958+
.addReg(AArch64::ZT0)
959+
.addReg(ZT0Save);
960+
} else {
961+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
962+
.addReg(ZT0Save);
963+
}
964+
}
965+
894966
void MachineSMEABI::emitAllocateFullZASaveBuffer(
895967
EmitContext &Context, MachineBasicBlock &MBB,
896968
MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
@@ -935,6 +1007,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
9351007
restorePhyRegSave(RegSave, MBB, MBBI, DL);
9361008
}
9371009

1010+
struct FromState {
1011+
ZAState From;
1012+
1013+
constexpr uint8_t to(ZAState To) const {
1014+
static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1015+
return uint8_t(From) << 4 | uint8_t(To);
1016+
}
1017+
};
1018+
1019+
constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }
1020+
9381021
void MachineSMEABI::emitStateChange(EmitContext &Context,
9391022
MachineBasicBlock &MBB,
9401023
MachineBasicBlock::iterator InsertPt,
@@ -949,8 +1032,6 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
9491032
if (From == ZAState::ENTRY && To == ZAState::OFF)
9501033
return;
9511034

952-
[[maybe_unused]] SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
953-
9541035
// TODO: Avoid setting up the save buffer if there's no transition to
9551036
// LOCAL_SAVED.
9561037
if (From == ZAState::ENTRY) {
@@ -966,17 +1047,67 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
9661047
From = ZAState::ACTIVE;
9671048
}
9681049

969-
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
970-
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
971-
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
972-
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
973-
else if (To == ZAState::OFF) {
974-
assert(From != ZAState::ENTRY &&
975-
"ENTRY to OFF should have already been handled");
976-
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
977-
"Should not turn ZA off in agnostic ZA function");
978-
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
979-
} else {
1050+
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1051+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1052+
bool HasZT0State = SMEFnAttrs.hasZT0State();
1053+
bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1054+
1055+
switch (transitionFrom(From).to(To)) {
1056+
// This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1057+
case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
1058+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1059+
break;
1060+
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
1061+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1062+
break;
1063+
1064+
// This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
1065+
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
1066+
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED):
1067+
if (HasZT0State && From == ZAState::ACTIVE)
1068+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1069+
if (HasZAState)
1070+
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
1071+
break;
1072+
1073+
// This section handles: ACTIVE -> LOCAL_COMMITTED
1074+
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
1075+
// TODO: We could support ZA state here, but this transition is currently
1076+
// only possible when we _don't_ have ZA state.
1077+
assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1078+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1079+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1080+
break;
1081+
1082+
// This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1083+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
1084+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
1085+
// These transistions are a no-op.
1086+
break;
1087+
1088+
// This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1089+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
1090+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
1091+
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
1092+
if (HasZAState)
1093+
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
1094+
else
1095+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
1096+
if (HasZT0State && To == ZAState::ACTIVE)
1097+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1098+
break;
1099+
1100+
// This section handles transistions to OFF (not previously covered)
1101+
case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF):
1102+
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF):
1103+
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF):
1104+
assert(SMEFnAttrs.hasPrivateZAInterface() &&
1105+
"Did not expect to turn ZA off in shared/agnostic ZA function");
1106+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1107+
/*On=*/false);
1108+
break;
1109+
1110+
default:
9801111
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
9811112
<< getZAStateString(To) << '\n';
9821113
llvm_unreachable("Unimplemented state transition");

0 commit comments

Comments
 (0)