77// ===----------------------------------------------------------------------===//
88//
99// This pass implements the SME ABI requirements for ZA state. This includes
10- // implementing the lazy ZA state save schemes around calls.
10+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
1111//
1212// ===----------------------------------------------------------------------===//
1313
@@ -128,7 +128,7 @@ struct MachineSMEABI : public MachineFunctionPass {
128128
129129 void collectNeededZAStates (MachineFunction &MF, SMEAttrs);
130130 void pickBundleZAStates (MachineFunction &MF);
131- void insertStateChanges (MachineFunction &MF);
131+ void insertStateChanges (MachineFunction &MF, bool IsAgnosticZA );
132132
133133 // Emission routines for private and shared ZA functions (using lazy saves).
134134 void emitNewZAPrologue (MachineBasicBlock &MBB,
@@ -143,11 +143,46 @@ struct MachineSMEABI : public MachineFunctionPass {
143143 void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
144144 bool ClearTPIDR2);
145145
146+ // Emission routines for agnostic ZA functions.
147+ void emitSetupFullZASave (MachineBasicBlock &MBB,
148+ MachineBasicBlock::iterator MBBI,
149+ LiveRegs PhysLiveRegs);
150+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
151+ MachineBasicBlock::iterator MBBI,
152+ LiveRegs PhysLiveRegs, bool IsSave);
153+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
154+ MachineBasicBlock::iterator MBBI,
155+ LiveRegs PhysLiveRegs);
156+
146157 void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
147- ZAState From, ZAState To, LiveRegs PhysLiveRegs);
158+ ZAState From, ZAState To, LiveRegs PhysLiveRegs,
159+ bool IsAgnosticZA);
160+
161+ // Helpers for switching between lazy/full ZA save/restore routines.
162+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
163+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
164+ if (IsAgnosticZA)
165+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
166+ return emitSetupLazySave (MBB, MBBI);
167+ }
168+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
169+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
170+ if (IsAgnosticZA)
171+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
172+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
173+ }
174+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
175+ MachineBasicBlock::iterator MBBI,
176+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
177+ if (IsAgnosticZA)
178+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
179+ return emitAllocateLazySaveBuffer (MBB, MBBI);
180+ }
148181
149182 TPIDR2State getTPIDR2Block (MachineFunction &MF);
150183
184+ Register getAgnosticZABufferPtr (MachineFunction &MF);
185+
151186private:
152187 struct InstInfo {
153188 ZAState NeededState{ZAState::ANY};
@@ -158,6 +193,7 @@ struct MachineSMEABI : public MachineFunctionPass {
158193 struct BlockInfo {
159194 ZAState FixedEntryState{ZAState::ANY};
160195 SmallVector<InstInfo> Insts;
196+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
161197 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
162198 };
163199
@@ -167,6 +203,9 @@ struct MachineSMEABI : public MachineFunctionPass {
167203 SmallVector<ZAState> BundleStates;
168204 std::optional<TPIDR2State> TPIDR2Block;
169205 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
206+ Register AgnosticZABufferPtr = AArch64::NoRegister;
207+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
208+ bool HasFullZASaveRestore = false ;
170209 } State;
171210
172211 EdgeBundles *Bundles = nullptr ;
@@ -175,7 +214,8 @@ struct MachineSMEABI : public MachineFunctionPass {
175214void MachineSMEABI::collectNeededZAStates (MachineFunction &MF,
176215 SMEAttrs SMEFnAttrs) {
177216 const TargetRegisterInfo &TRI = *MF.getSubtarget ().getRegisterInfo ();
178- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
217+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
218+ SMEFnAttrs.hasZAState ()) &&
179219 " Expected function to have ZA/ZT0 state!" );
180220
181221 State.Blocks .resize (MF.getNumBlockIDs ());
@@ -209,6 +249,7 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
209249
210250 Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
211251 auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
252+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
212253 for (MachineInstr &MI : reverse (MBB)) {
213254 MachineBasicBlock::iterator MBBI (MI);
214255 LiveUnits.stepBackward (MI);
@@ -219,15 +260,20 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
219260 // block setup.
220261 if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
221262 State.AfterSMEProloguePt = MBBI;
263+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
222264 }
265+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
223266 auto [NeededState, InsertPt] = getInstNeededZAState (
224- TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface ());
267+ TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface () ||
268+ SMEFnAttrs.hasAgnosticZAInterface ());
225269 assert ((InsertPt == MBBI ||
226270 InsertPt->getOpcode () == AArch64::ADJCALLSTACKDOWN) &&
227271 " Unexpected state change insertion point!" );
228272 // TODO: Do something to avoid state changes where NZCV is live.
229273 if (MBBI == FirstTerminatorInsertPt)
230274 Block.PhysLiveRegsAtExit = PhysLiveRegs;
275+ if (MBBI == FirstNonPhiInsertPt)
276+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
231277 if (NeededState != ZAState::ANY)
232278 Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
233279 }
@@ -294,7 +340,7 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
294340 }
295341}
296342
297- void MachineSMEABI::insertStateChanges (MachineFunction &MF) {
343+ void MachineSMEABI::insertStateChanges (MachineFunction &MF, bool IsAgnosticZA ) {
298344 for (MachineBasicBlock &MBB : MF) {
299345 BlockInfo &Block = State.Blocks [MBB.getNumber ()];
300346 ZAState InState =
@@ -309,7 +355,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
309355 for (auto &Inst : Block.Insts ) {
310356 if (CurrentState != Inst.NeededState )
311357 emitStateChange (MBB, Inst.InsertPt , CurrentState, Inst.NeededState ,
312- Inst.PhysLiveRegs );
358+ Inst.PhysLiveRegs , IsAgnosticZA );
313359 CurrentState = Inst.NeededState ;
314360 }
315361
@@ -318,7 +364,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
318364
319365 if (CurrentState != OutState)
320366 emitStateChange (MBB, MBB.getFirstTerminator (), CurrentState, OutState,
321- Block.PhysLiveRegsAtExit );
367+ Block.PhysLiveRegsAtExit , IsAgnosticZA );
322368 }
323369}
324370
@@ -573,10 +619,98 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
573619 emitZeroZA (TII, DL, MBB, MBBI, /* Mask=*/ 0b11111111 );
574620}
575621
622+ Register MachineSMEABI::getAgnosticZABufferPtr (MachineFunction &MF) {
623+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
624+ return State.AgnosticZABufferPtr ;
625+ if (auto BufferPtr =
626+ MF.getInfo <AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer ();
627+ BufferPtr != AArch64::NoRegister)
628+ State.AgnosticZABufferPtr = BufferPtr;
629+ else
630+ State.AgnosticZABufferPtr =
631+ MF.getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
632+ return State.AgnosticZABufferPtr ;
633+ }
634+
635+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
636+ MachineBasicBlock::iterator MBBI,
637+ LiveRegs PhysLiveRegs, bool IsSave) {
638+ MachineFunction &MF = *MBB.getParent ();
639+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
640+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo ();
641+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
642+ MachineRegisterInfo &MRI = MF.getRegInfo ();
643+
644+ State.HasFullZASaveRestore = true ;
645+ DebugLoc DL = getDebugLoc (MBB, MBBI);
646+ Register BufferPtr = AArch64::X0;
647+
648+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
649+
650+ // Copy the buffer pointer into X0.
651+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
652+ .addReg (getAgnosticZABufferPtr (MF));
653+
654+ // Call __arm_sme_save/__arm_sme_restore.
655+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
656+ .addReg (BufferPtr, RegState::Implicit)
657+ .addExternalSymbol (IsSave ? " __arm_sme_save" : " __arm_sme_restore" )
658+ .addRegMask (TRI.getCallPreservedMask (
659+ MF,
660+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
661+ }
662+
663+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
664+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
665+ LiveRegs PhysLiveRegs) {
666+ MachineFunction &MF = *MBB.getParent ();
667+ MachineFrameInfo &MFI = MF.getFrameInfo ();
668+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
669+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
670+ MachineRegisterInfo &MRI = MF.getRegInfo ();
671+ auto *AFI = MF.getInfo <AArch64FunctionInfo>();
672+
673+ // Buffer already allocated in SelectionDAG.
674+ if (AFI->getEarlyAllocSMESaveBuffer ())
675+ return ;
676+
677+ DebugLoc DL = getDebugLoc (MBB, MBBI);
678+ Register BufferPtr = getAgnosticZABufferPtr (MF);
679+ Register BufferSize = MRI.createVirtualRegister (&AArch64::GPR64RegClass);
680+
681+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
682+
683+ // Calculate the SME state size.
684+ {
685+ const AArch64RegisterInfo *TRI = Subtarget.getRegisterInfo ();
686+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
687+ .addExternalSymbol (" __arm_sme_state_size" )
688+ .addReg (AArch64::X0, RegState::ImplicitDefine)
689+ .addRegMask (TRI->getCallPreservedMask (
690+ MF, CallingConv::
691+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
692+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferSize)
693+ .addReg (AArch64::X0);
694+ }
695+
696+ // Allocate a buffer object of the size given __arm_sme_state_size.
697+ {
698+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::SUBXrx64), AArch64::SP)
699+ .addReg (AArch64::SP)
700+ .addReg (BufferSize)
701+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
702+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
703+ .addReg (AArch64::SP);
704+
705+ // We have just allocated a variable sized object, tell this to PEI.
706+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
707+ }
708+ }
709+
576710void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
577711 MachineBasicBlock::iterator InsertPt,
578712 ZAState From, ZAState To,
579- LiveRegs PhysLiveRegs) {
713+ LiveRegs PhysLiveRegs, bool IsAgnosticZA ) {
580714
581715 // ZA not used.
582716 if (From == ZAState::ANY || To == ZAState::ANY)
@@ -603,10 +737,11 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
603737 }
604738
605739 if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
606- emitSetupLazySave (MBB, InsertPt);
740+ emitZASave (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
607741 else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
608- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
742+ emitZARestore (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
609743 else if (To == ZAState::OFF) {
744+ assert (!IsAgnosticZA && " Should not turn ZA off in agnostic ZA function" );
610745 // If we're exiting from the CALLER_DORMANT state that means this new ZA
611746 // function did not touch ZA (so ZA was never turned on).
612747 if (From != ZAState::CALLER_DORMANT)
@@ -629,7 +764,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
629764
630765 auto *AFI = MF.getInfo <AArch64FunctionInfo>();
631766 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
632- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
767+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
768+ !SMEFnAttrs.hasAgnosticZAInterface ())
633769 return false ;
634770
635771 assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -638,20 +774,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
638774 State = PassState{};
639775 Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles ();
640776
777+ bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface ();
778+
641779 collectNeededZAStates (MF, SMEFnAttrs);
642780 pickBundleZAStates (MF);
643- insertStateChanges (MF);
781+ insertStateChanges (MF, /* IsAgnosticZA= */ IsAgnosticZA );
644782
645783 // Allocate save buffer (if needed).
646- if (State.TPIDR2Block .has_value ()) {
784+ if (State.HasFullZASaveRestore || State. TPIDR2Block .has_value ()) {
647785 if (State.AfterSMEProloguePt ) {
648786 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
649787 // entry block (due to the probing loop).
650- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
651- *State.AfterSMEProloguePt );
788+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
789+ *State.AfterSMEProloguePt ,
790+ State.PhysLiveRegsAfterSMEPrologue ,
791+ /* IsAgnosticZA=*/ IsAgnosticZA);
652792 } else {
653793 MachineBasicBlock &EntryBlock = MF.front ();
654- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
794+ emitAllocateZASaveBuffer (
795+ EntryBlock, EntryBlock.getFirstNonPHI (),
796+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry ,
797+ /* IsAgnosticZA=*/ IsAgnosticZA);
655798 }
656799 }
657800
0 commit comments