77// ===----------------------------------------------------------------------===//
88//
99// This pass implements the SME ABI requirements for ZA state. This includes
10- // implementing the lazy ZA state save schemes around calls.
10+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
1111//
1212// ===----------------------------------------------------------------------===//
1313//
@@ -200,7 +200,7 @@ struct MachineSMEABI : public MachineFunctionPass {
200200
201201 // / Inserts code to handle changes between ZA states within the function.
202202 // / E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
203- void insertStateChanges ();
203+ void insertStateChanges (bool IsAgnosticZA );
204204
205205 // Emission routines for private and shared ZA functions (using lazy saves).
206206 void emitNewZAPrologue (MachineBasicBlock &MBB,
@@ -215,8 +215,41 @@ struct MachineSMEABI : public MachineFunctionPass {
215215 void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
216216 bool ClearTPIDR2);
217217
218+ // Emission routines for agnostic ZA functions.
219+ void emitSetupFullZASave (MachineBasicBlock &MBB,
220+ MachineBasicBlock::iterator MBBI,
221+ LiveRegs PhysLiveRegs);
222+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
223+ MachineBasicBlock::iterator MBBI,
224+ LiveRegs PhysLiveRegs, bool IsSave);
225+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
226+ MachineBasicBlock::iterator MBBI,
227+ LiveRegs PhysLiveRegs);
228+
218229 void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
219- ZAState From, ZAState To, LiveRegs PhysLiveRegs);
230+ ZAState From, ZAState To, LiveRegs PhysLiveRegs,
231+ bool IsAgnosticZA);
232+
233+ // Helpers for switching between lazy/full ZA save/restore routines.
234+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
235+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
236+ if (IsAgnosticZA)
237+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
238+ return emitSetupLazySave (MBB, MBBI);
239+ }
240+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
241+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
242+ if (IsAgnosticZA)
243+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
244+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
245+ }
246+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
247+ MachineBasicBlock::iterator MBBI,
248+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
249+ if (IsAgnosticZA)
250+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
251+ return emitAllocateLazySaveBuffer (MBB, MBBI);
252+ }
220253
221254 // / Save live physical registers to virtual registers.
222255 PhysRegSave createPhysRegSave (LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
@@ -228,6 +261,8 @@ struct MachineSMEABI : public MachineFunctionPass {
228261 // / Get or create a TPIDR2 block in this function.
229262 TPIDR2State getTPIDR2Block ();
230263
264+ Register getAgnosticZABufferPtr ();
265+
231266private:
232267 // / Contains the needed ZA state (and live registers) at an instruction.
233268 struct InstInfo {
@@ -241,6 +276,7 @@ struct MachineSMEABI : public MachineFunctionPass {
241276 struct BlockInfo {
242277 ZAState FixedEntryState{ZAState::ANY};
243278 SmallVector<InstInfo> Insts;
279+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
244280 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
245281 };
246282
@@ -250,6 +286,9 @@ struct MachineSMEABI : public MachineFunctionPass {
250286 SmallVector<ZAState> BundleStates;
251287 std::optional<TPIDR2State> TPIDR2Block;
252288 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
289+ Register AgnosticZABufferPtr = AArch64::NoRegister;
290+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
291+ bool HasFullZASaveRestore = false ;
253292 } State;
254293
255294 MachineFunction *MF = nullptr ;
@@ -261,7 +300,8 @@ struct MachineSMEABI : public MachineFunctionPass {
261300};
262301
263302void MachineSMEABI::collectNeededZAStates (SMEAttrs SMEFnAttrs) {
264- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
303+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
304+ SMEFnAttrs.hasZAState ()) &&
265305 " Expected function to have ZA/ZT0 state!" );
266306
267307 State.Blocks .resize (MF->getNumBlockIDs ());
@@ -295,6 +335,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
295335
296336 Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
297337 auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
338+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
298339 for (MachineInstr &MI : reverse (MBB)) {
299340 MachineBasicBlock::iterator MBBI (MI);
300341 LiveUnits.stepBackward (MI);
@@ -305,7 +346,9 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
305346 // block setup.
306347 if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
307348 State.AfterSMEProloguePt = MBBI;
349+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
308350 }
351+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
309352 auto [NeededState, InsertPt] = getZAStateBeforeInst (
310353 *TRI, MI, /* ZAOffAtReturn=*/ SMEFnAttrs.hasPrivateZAInterface ());
311354 assert ((InsertPt == MBBI ||
@@ -314,6 +357,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
314357 // TODO: Do something to avoid state changes where NZCV is live.
315358 if (MBBI == FirstTerminatorInsertPt)
316359 Block.PhysLiveRegsAtExit = PhysLiveRegs;
360+ if (MBBI == FirstNonPhiInsertPt)
361+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
317362 if (NeededState != ZAState::ANY)
318363 Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
319364 }
@@ -380,7 +425,7 @@ void MachineSMEABI::assignBundleZAStates() {
380425 }
381426}
382427
383- void MachineSMEABI::insertStateChanges () {
428+ void MachineSMEABI::insertStateChanges (bool IsAgnosticZA ) {
384429 for (MachineBasicBlock &MBB : *MF) {
385430 const BlockInfo &Block = State.Blocks [MBB.getNumber ()];
386431 ZAState InState = State.BundleStates [Bundles->getBundle (MBB.getNumber (),
@@ -393,7 +438,7 @@ void MachineSMEABI::insertStateChanges() {
393438 for (auto &Inst : Block.Insts ) {
394439 if (CurrentState != Inst.NeededState )
395440 emitStateChange (MBB, Inst.InsertPt , CurrentState, Inst.NeededState ,
396- Inst.PhysLiveRegs );
441+ Inst.PhysLiveRegs , IsAgnosticZA );
397442 CurrentState = Inst.NeededState ;
398443 }
399444
@@ -404,7 +449,7 @@ void MachineSMEABI::insertStateChanges() {
404449 State.BundleStates [Bundles->getBundle (MBB.getNumber (), /* Out=*/ true )];
405450 if (CurrentState != OutState)
406451 emitStateChange (MBB, MBB.getFirstTerminator (), CurrentState, OutState,
407- Block.PhysLiveRegsAtExit );
452+ Block.PhysLiveRegsAtExit , IsAgnosticZA );
408453 }
409454}
410455
@@ -618,10 +663,95 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
618663 .addImm (1 );
619664}
620665
666+ Register MachineSMEABI::getAgnosticZABufferPtr () {
667+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
668+ return State.AgnosticZABufferPtr ;
669+ if (auto BufferPtr =
670+ MF->getInfo <AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer ();
671+ BufferPtr != AArch64::NoRegister)
672+ State.AgnosticZABufferPtr = BufferPtr;
673+ else
674+ State.AgnosticZABufferPtr =
675+ MF->getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
676+ return State.AgnosticZABufferPtr ;
677+ }
678+
679+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
680+ MachineBasicBlock::iterator MBBI,
681+ LiveRegs PhysLiveRegs, bool IsSave) {
682+ auto *TLI = Subtarget->getTargetLowering ();
683+ State.HasFullZASaveRestore = true ;
684+ DebugLoc DL = getDebugLoc (MBB, MBBI);
685+ Register BufferPtr = AArch64::X0;
686+
687+ PhysRegSave RegSave = createPhysRegSave (PhysLiveRegs, MBB, MBBI, DL);
688+
689+ // Copy the buffer pointer into X0.
690+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferPtr)
691+ .addReg (getAgnosticZABufferPtr ());
692+
693+ // Call __arm_sme_save/__arm_sme_restore.
694+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::BL))
695+ .addReg (BufferPtr, RegState::Implicit)
696+ .addExternalSymbol (TLI->getLibcallName (
697+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
698+ .addRegMask (TRI->getCallPreservedMask (
699+ *MF,
700+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
701+
702+ restorePhyRegSave (RegSave, MBB, MBBI, DL);
703+ }
704+
705+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
706+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
707+ LiveRegs PhysLiveRegs) {
708+ auto *AFI = MF->getInfo <AArch64FunctionInfo>();
709+
710+ // Buffer already allocated in SelectionDAG.
711+ if (AFI->getEarlyAllocSMESaveBuffer ())
712+ return ;
713+
714+ DebugLoc DL = getDebugLoc (MBB, MBBI);
715+ Register BufferPtr = getAgnosticZABufferPtr ();
716+ Register BufferSize = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
717+
718+ PhysRegSave RegSave = createPhysRegSave (PhysLiveRegs, MBB, MBBI, DL);
719+
720+ // Calculate the SME state size.
721+ {
722+ auto *TLI = Subtarget->getTargetLowering ();
723+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo ();
724+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::BL))
725+ .addExternalSymbol (TLI->getLibcallName (RTLIB::SMEABI_SME_STATE_SIZE))
726+ .addReg (AArch64::X0, RegState::ImplicitDefine)
727+ .addRegMask (TRI->getCallPreservedMask (
728+ *MF, CallingConv::
729+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
730+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferSize)
731+ .addReg (AArch64::X0);
732+ }
733+
734+ // Allocate a buffer object of the size given __arm_sme_state_size.
735+ {
736+ MachineFrameInfo &MFI = MF->getFrameInfo ();
737+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::SUBXrx64), AArch64::SP)
738+ .addReg (AArch64::SP)
739+ .addReg (BufferSize)
740+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
741+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferPtr)
742+ .addReg (AArch64::SP);
743+
744+ // We have just allocated a variable sized object, tell this to PEI.
745+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
746+ }
747+
748+ restorePhyRegSave (RegSave, MBB, MBBI, DL);
749+ }
750+
621751void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
622752 MachineBasicBlock::iterator InsertPt,
623753 ZAState From, ZAState To,
624- LiveRegs PhysLiveRegs) {
754+ LiveRegs PhysLiveRegs, bool IsAgnosticZA ) {
625755
626756 // ZA not used.
627757 if (From == ZAState::ANY || To == ZAState::ANY)
@@ -653,12 +783,13 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
653783 }
654784
655785 if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
656- emitSetupLazySave (MBB, InsertPt);
786+ emitZASave (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
657787 else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
658- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
788+ emitZARestore (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
659789 else if (To == ZAState::OFF) {
660790 assert (From != ZAState::CALLER_DORMANT &&
661791 " CALLER_DORMANT to OFF should have already been handled" );
792+ assert (!IsAgnosticZA && " Should not turn ZA off in agnostic ZA function" );
662793 emitZAOff (MBB, InsertPt, /* ClearTPIDR2=*/ From == ZAState::LOCAL_SAVED);
663794 } else {
664795 dbgs () << " Error: Transition from " << getZAStateString (From) << " to "
@@ -678,7 +809,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
678809
679810 auto *AFI = MF.getInfo <AArch64FunctionInfo>();
680811 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
681- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
812+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
813+ !SMEFnAttrs.hasAgnosticZAInterface ())
682814 return false ;
683815
684816 assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -692,20 +824,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
692824 TRI = Subtarget->getRegisterInfo ();
693825 MRI = &MF.getRegInfo ();
694826
827+ bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface ();
828+
695829 collectNeededZAStates (SMEFnAttrs);
696830 assignBundleZAStates ();
697- insertStateChanges ();
831+ insertStateChanges (/* IsAgnosticZA= */ IsAgnosticZA );
698832
699833 // Allocate save buffer (if needed).
700- if (State.TPIDR2Block ) {
834+ if (State.HasFullZASaveRestore || State. TPIDR2Block ) {
701835 if (State.AfterSMEProloguePt ) {
702836 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
703837 // entry block (due to the probing loop).
704- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
705- *State.AfterSMEProloguePt );
838+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
839+ *State.AfterSMEProloguePt ,
840+ State.PhysLiveRegsAfterSMEPrologue ,
841+ /* IsAgnosticZA=*/ IsAgnosticZA);
706842 } else {
707843 MachineBasicBlock &EntryBlock = MF.front ();
708- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
844+ emitAllocateZASaveBuffer (
845+ EntryBlock, EntryBlock.getFirstNonPHI (),
846+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry ,
847+ /* IsAgnosticZA=*/ IsAgnosticZA);
709848 }
710849 }
711850
0 commit comments