@@ -72,20 +72,34 @@ using namespace llvm;
7272
7373namespace {
7474
75- enum ZAState {
75+ // Note: For agnostic ZA, we assume the function is always entered/exited in the
76+ // "ACTIVE" state -- this _may_ not be the case (since OFF is also a
77+ // possibility, but for the purpose of placing ZA saves/restores, that does not
78+ // matter).
79+ enum ZAState : uint8_t {
7680 // Any/unknown state (not valid)
7781 ANY = 0 ,
7882
7983 // ZA is in use and active (i.e. within the accumulator)
8084 ACTIVE,
8185
86+ // ZA is active, but ZT0 has been saved.
87+ // This handles the edge case of sharesZA && !sharesZT0.
88+ ACTIVE_ZT0_SAVED,
89+
8290 // A ZA save has been set up or committed (i.e. ZA is dormant or off)
91+ // If the function uses ZT0 it must also be saved.
8392 LOCAL_SAVED,
8493
94+ // ZA has been committed to the lazy save buffer of the current function.
95+ // If the function uses ZT0 it must also be saved.
96+ // ZA is off.
97+ LOCAL_COMMITTED,
98+
8599 // The ZA/ZT0 state on entry to the function.
86100 ENTRY,
87101
88- // ZA is off
102+ // ZA is off.
89103 OFF,
90104
91105 // The number of ZA states (not a valid state)
@@ -164,6 +178,14 @@ class EmitContext {
164178 return AgnosticZABufferPtr;
165179 }
166180
181+ int getZT0SaveSlot (MachineFunction &MF) {
182+ if (ZT0SaveFI)
183+ return *ZT0SaveFI;
184+ MachineFrameInfo &MFI = MF.getFrameInfo ();
185+ ZT0SaveFI = MFI.CreateSpillStackObject (64 , Align (16 ));
186+ return *ZT0SaveFI;
187+ }
188+
167189 // / Returns true if the function must allocate a ZA save buffer on entry. This
168190 // / will be the case if, at any point in the function, a ZA save was emitted.
169191 bool needsSaveBuffer () const {
@@ -173,6 +195,7 @@ class EmitContext {
173195 }
174196
175197private:
198+ std::optional<int > ZT0SaveFI;
176199 std::optional<int > TPIDR2BlockFI;
177200 Register AgnosticZABufferPtr = AArch64::NoRegister;
178201};
@@ -184,8 +207,10 @@ class EmitContext {
184207// / state would not be legal, as transitioning to it drops the content of ZA.
185208static bool isLegalEdgeBundleZAState (ZAState State) {
186209 switch (State) {
187- case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
188- case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
210+ case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
211+ case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
212+ case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack.
213+ case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack.
189214 return true ;
190215 default :
191216 return false ;
@@ -199,7 +224,9 @@ StringRef getZAStateString(ZAState State) {
199224 switch (State) {
200225 MAKE_CASE (ZAState::ANY)
201226 MAKE_CASE (ZAState::ACTIVE)
227+ MAKE_CASE (ZAState::ACTIVE_ZT0_SAVED)
202228 MAKE_CASE (ZAState::LOCAL_SAVED)
229+ MAKE_CASE (ZAState::LOCAL_COMMITTED)
203230 MAKE_CASE (ZAState::ENTRY)
204231 MAKE_CASE (ZAState::OFF)
205232 default :
@@ -221,18 +248,39 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
221248// / Returns the required ZA state needed before \p MI and an iterator pointing
222249// / to where any code required to change the ZA state should be inserted.
223250static std::pair<ZAState, MachineBasicBlock::iterator>
224- getZAStateBeforeInst (const TargetRegisterInfo &TRI, MachineInstr &MI,
225- bool ZAOffAtReturn ) {
251+ getInstNeededZAState (const TargetRegisterInfo &TRI, MachineInstr &MI,
252+ SMEAttrs SMEFnAttrs ) {
226253 MachineBasicBlock::iterator InsertPt (MI);
227254
255+ // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
256+ // intended to mark the position immediately before a call. Due to
257+ // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
258+ // so we use std::prev(InsertPt) to get the position before the call.
259+
228260 if (MI.getOpcode () == AArch64::InOutZAUsePseudo)
229261 return {ZAState::ACTIVE, std::prev (InsertPt)};
230262
263+ // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
231264 if (MI.getOpcode () == AArch64::RequiresZASavePseudo)
232265 return {ZAState::LOCAL_SAVED, std::prev (InsertPt)};
233266
234- if (MI.isReturn ())
267+ // If we only need to save ZT0 there are two cases to consider:
268+ // 1. The function has ZA state (that we don't need to save).
269+ // - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
270+ // This only saves ZT0.
271+ // 2. The function does not have ZA state
272+ // - In this case we switch to "LOCAL_COMMITTED" state.
273+ // This saves ZT0 and turns ZA off.
274+ if (MI.getOpcode () == AArch64::RequiresZT0SavePseudo) {
275+ return {SMEFnAttrs.hasZAState () ? ZAState::ACTIVE_ZT0_SAVED
276+ : ZAState::LOCAL_COMMITTED,
277+ std::prev (InsertPt)};
278+ }
279+
280+ if (MI.isReturn ()) {
281+ bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface ();
235282 return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
283+ }
236284
237285 for (auto &MO : MI.operands ()) {
238286 if (isZAorZTRegOp (TRI, MO))
@@ -280,6 +328,9 @@ struct MachineSMEABI : public MachineFunctionPass {
280328 // / predecessors).
281329 void propagateDesiredStates (FunctionInfo &FnInfo, bool Forwards = true );
282330
331+ void emitZT0SaveRestore (EmitContext &, MachineBasicBlock &MBB,
332+ MachineBasicBlock::iterator MBBI, bool IsSave);
333+
283334 // Emission routines for private and shared ZA functions (using lazy saves).
284335 void emitSMEPrologue (MachineBasicBlock &MBB,
285336 MachineBasicBlock::iterator MBBI);
@@ -290,8 +341,8 @@ struct MachineSMEABI : public MachineFunctionPass {
290341 MachineBasicBlock::iterator MBBI);
291342 void emitAllocateLazySaveBuffer (EmitContext &, MachineBasicBlock &MBB,
292343 MachineBasicBlock::iterator MBBI);
293- void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
294- bool ClearTPIDR2);
344+ void emitZAMode (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
345+ bool ClearTPIDR2, bool On );
295346
296347 // Emission routines for agnostic ZA functions.
297348 void emitSetupFullZASave (MachineBasicBlock &MBB,
@@ -409,7 +460,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
409460 Block.FixedEntryState = ZAState::ENTRY;
410461 } else if (MBB.isEHPad ()) {
411462 // EH entry block:
412- Block.FixedEntryState = ZAState::LOCAL_SAVED ;
463+ Block.FixedEntryState = ZAState::LOCAL_COMMITTED ;
413464 }
414465
415466 LiveRegUnits LiveUnits (*TRI);
@@ -431,8 +482,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
431482 PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
432483 }
433484 // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
434- auto [NeededState, InsertPt] = getZAStateBeforeInst (
435- *TRI, MI, /* ZAOffAtReturn=*/ SMEFnAttrs.hasPrivateZAInterface ());
485+ auto [NeededState, InsertPt] = getInstNeededZAState (*TRI, MI, SMEFnAttrs);
436486 assert ((InsertPt == MBBI || isCallStartOpcode (InsertPt->getOpcode ())) &&
437487 " Unexpected state change insertion point!" );
438488 // TODO: Do something to avoid state changes where NZCV is live.
@@ -752,9 +802,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
752802 restorePhyRegSave (RegSave, MBB, MBBI, DL);
753803}
754804
755- void MachineSMEABI::emitZAOff (MachineBasicBlock &MBB,
756- MachineBasicBlock::iterator MBBI,
757- bool ClearTPIDR2) {
805+ void MachineSMEABI::emitZAMode (MachineBasicBlock &MBB,
806+ MachineBasicBlock::iterator MBBI,
807+ bool ClearTPIDR2, bool On ) {
758808 DebugLoc DL = getDebugLoc (MBB, MBBI);
759809
760810 if (ClearTPIDR2)
@@ -765,7 +815,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
766816 // Enable or disable ZA, depending on `On`.
766816 BuildMI (MBB, MBBI, DL, TII->get (AArch64::MSRpstatesvcrImm1))
767817 .addImm (AArch64SVCR::SVCRZA)
768- .addImm (0 );
818+ .addImm (On ? 1 : 0 );
769819}
770820
771821void MachineSMEABI::emitAllocateLazySaveBuffer (
@@ -891,6 +941,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
891941 restorePhyRegSave (RegSave, MBB, MBBI, DL);
892942}
893943
944+ void MachineSMEABI::emitZT0SaveRestore (EmitContext &Context,
945+ MachineBasicBlock &MBB,
946+ MachineBasicBlock::iterator MBBI,
947+ bool IsSave) {
948+ DebugLoc DL = getDebugLoc (MBB, MBBI);
949+ Register ZT0Save = MRI->createVirtualRegister (&AArch64::GPR64spRegClass);
950+
951+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::ADDXri), ZT0Save)
952+ .addFrameIndex (Context.getZT0SaveSlot (*MF))
953+ .addImm (0 )
954+ .addImm (0 );
955+
956+ if (IsSave) {
957+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::STR_TX))
958+ .addReg (AArch64::ZT0)
959+ .addReg (ZT0Save);
960+ } else {
961+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::LDR_TX), AArch64::ZT0)
962+ .addReg (ZT0Save);
963+ }
964+ }
965+
894966void MachineSMEABI::emitAllocateFullZASaveBuffer (
895967 EmitContext &Context, MachineBasicBlock &MBB,
896968 MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
@@ -935,6 +1007,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
9351007 restorePhyRegSave (RegSave, MBB, MBBI, DL);
9361008}
9371009
1010+ struct FromState {
1011+ ZAState From;
1012+
1013+ constexpr uint8_t to (ZAState To) const {
1014+ static_assert (NUM_ZA_STATE < 16 , " expected ZAState to fit in 4-bits" );
1015+ return uint8_t (From) << 4 | uint8_t (To);
1016+ }
1017+ };
1018+
1019+ constexpr FromState transitionFrom (ZAState From) { return FromState{From}; }
1020+
9381021void MachineSMEABI::emitStateChange (EmitContext &Context,
9391022 MachineBasicBlock &MBB,
9401023 MachineBasicBlock::iterator InsertPt,
@@ -949,8 +1032,6 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
9491032 if (From == ZAState::ENTRY && To == ZAState::OFF)
9501033 return ;
9511034
952- [[maybe_unused]] SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
953-
9541035 // TODO: Avoid setting up the save buffer if there's no transition to
9551036 // LOCAL_SAVED.
9561037 if (From == ZAState::ENTRY) {
@@ -966,17 +1047,67 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
9661047 From = ZAState::ACTIVE;
9671048 }
9681049
969- if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
970- emitZASave (Context, MBB, InsertPt, PhysLiveRegs);
971- else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
972- emitZARestore (Context, MBB, InsertPt, PhysLiveRegs);
973- else if (To == ZAState::OFF) {
974- assert (From != ZAState::ENTRY &&
975- " ENTRY to OFF should have already been handled" );
976- assert (!SMEFnAttrs.hasAgnosticZAInterface () &&
977- " Should not turn ZA off in agnostic ZA function" );
978- emitZAOff (MBB, InsertPt, /* ClearTPIDR2=*/ From == ZAState::LOCAL_SAVED);
979- } else {
1050+ SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
1051+ bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface ();
1052+ bool HasZT0State = SMEFnAttrs.hasZT0State ();
1053+ bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState ();
1054+
1055+ switch (transitionFrom (From).to (To)) {
1056+ // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1057+ case transitionFrom (ZAState::ACTIVE).to (ZAState::ACTIVE_ZT0_SAVED):
1058+ emitZT0SaveRestore (Context, MBB, InsertPt, /* IsSave=*/ true );
1059+ break ;
1060+ case transitionFrom (ZAState::ACTIVE_ZT0_SAVED).to (ZAState::ACTIVE):
1061+ emitZT0SaveRestore (Context, MBB, InsertPt, /* IsSave=*/ false );
1062+ break ;
1063+
1064+ // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
1065+ case transitionFrom (ZAState::ACTIVE).to (ZAState::LOCAL_SAVED):
1066+ case transitionFrom (ZAState::ACTIVE_ZT0_SAVED).to (ZAState::LOCAL_SAVED):
1067+ if (HasZT0State && From == ZAState::ACTIVE)
1068+ emitZT0SaveRestore (Context, MBB, InsertPt, /* IsSave=*/ true );
1069+ if (HasZAState)
1070+ emitZASave (Context, MBB, InsertPt, PhysLiveRegs);
1071+ break ;
1072+
1073+ // This section handles: ACTIVE -> LOCAL_COMMITTED
1074+ case transitionFrom (ZAState::ACTIVE).to (ZAState::LOCAL_COMMITTED):
1075+ // TODO: We could support ZA state here, but this transition is currently
1076+ // only possible when we _don't_ have ZA state.
1077+ assert (HasZT0State && !HasZAState && " Expect to only have ZT0 state." );
1078+ emitZT0SaveRestore (Context, MBB, InsertPt, /* IsSave=*/ true );
1079+ emitZAMode (MBB, InsertPt, /* ClearTPIDR2=*/ false , /* On=*/ false );
1080+ break ;
1081+
1082+ // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1083+ case transitionFrom (ZAState::LOCAL_COMMITTED).to (ZAState::OFF):
1084+ case transitionFrom (ZAState::LOCAL_COMMITTED).to (ZAState::LOCAL_SAVED):
1085+ // These transitions are no-ops.
1086+ break ;
1087+
1088+ // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1089+ case transitionFrom (ZAState::LOCAL_COMMITTED).to (ZAState::ACTIVE):
1090+ case transitionFrom (ZAState::LOCAL_COMMITTED).to (ZAState::ACTIVE_ZT0_SAVED):
1091+ case transitionFrom (ZAState::LOCAL_SAVED).to (ZAState::ACTIVE):
1092+ if (HasZAState)
1093+ emitZARestore (Context, MBB, InsertPt, PhysLiveRegs);
1094+ else
1095+ emitZAMode (MBB, InsertPt, /* ClearTPIDR2=*/ false , /* On=*/ true );
1096+ if (HasZT0State && To == ZAState::ACTIVE)
1097+ emitZT0SaveRestore (Context, MBB, InsertPt, /* IsSave=*/ false );
1098+ break ;
1099+
1100+ // This section handles transitions to OFF (not previously covered)
1101+ case transitionFrom (ZAState::ACTIVE).to (ZAState::OFF):
1102+ case transitionFrom (ZAState::ACTIVE_ZT0_SAVED).to (ZAState::OFF):
1103+ case transitionFrom (ZAState::LOCAL_SAVED).to (ZAState::OFF):
1104+ assert (SMEFnAttrs.hasPrivateZAInterface () &&
1105+ " Did not expect to turn ZA off in shared/agnostic ZA function" );
1106+ emitZAMode (MBB, InsertPt, /* ClearTPIDR2=*/ From == ZAState::LOCAL_SAVED,
1107+ /* On=*/ false );
1108+ break ;
1109+
1110+ default :
9801111 dbgs () << " Error: Transition from " << getZAStateString (From) << " to "
9811112 << getZAStateString (To) << ' \n ' ;
9821113 llvm_unreachable (" Unimplemented state transition" );
0 commit comments