77// ===----------------------------------------------------------------------===//
88//
99// This pass implements the SME ABI requirements for ZA state. This includes
10- // implementing the lazy ZA state save schemes around calls.
10+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
1111//
1212// ===----------------------------------------------------------------------===//
1313//
@@ -215,9 +215,44 @@ struct MachineSMEABI : public MachineFunctionPass {
215215 void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
216216 bool ClearTPIDR2);
217217
218+ // Emission routines for agnostic ZA functions.
219+ void emitSetupFullZASave (MachineBasicBlock &MBB,
220+ MachineBasicBlock::iterator MBBI,
221+ LiveRegs PhysLiveRegs);
222+ // Emit a "full" ZA save or restore. It is "full" in the sense that this
223+ // function will emit a call to __arm_sme_save or __arm_sme_restore, which
224+ // handles saving and restoring both ZA and ZT0.
225+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
226+ MachineBasicBlock::iterator MBBI,
227+ LiveRegs PhysLiveRegs, bool IsSave);
228+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
229+ MachineBasicBlock::iterator MBBI,
230+ LiveRegs PhysLiveRegs);
231+
218232 void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
219233 ZAState From, ZAState To, LiveRegs PhysLiveRegs);
220234
235+ // Helpers for switching between lazy/full ZA save/restore routines.
236+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
237+ LiveRegs PhysLiveRegs) {
238+ if (AFI->getSMEFnAttrs ().hasAgnosticZAInterface ())
239+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
240+ return emitSetupLazySave (MBB, MBBI);
241+ }
242+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
243+ LiveRegs PhysLiveRegs) {
244+ if (AFI->getSMEFnAttrs ().hasAgnosticZAInterface ())
245+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
246+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
247+ }
248+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
249+ MachineBasicBlock::iterator MBBI,
250+ LiveRegs PhysLiveRegs) {
251+ if (AFI->getSMEFnAttrs ().hasAgnosticZAInterface ())
252+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
253+ return emitAllocateLazySaveBuffer (MBB, MBBI);
254+ }
255+
221256 // / Save live physical registers to virtual registers.
222257 PhysRegSave createPhysRegSave (LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
223258 MachineBasicBlock::iterator MBBI, DebugLoc DL);
@@ -228,6 +263,8 @@ struct MachineSMEABI : public MachineFunctionPass {
228263 // / Get or create a TPIDR2 block in this function.
229264 TPIDR2State getTPIDR2Block ();
230265
266+ Register getAgnosticZABufferPtr ();
267+
231268private:
232269 // / Contains the needed ZA state (and live registers) at an instruction.
233270 struct InstInfo {
@@ -241,6 +278,7 @@ struct MachineSMEABI : public MachineFunctionPass {
241278 struct BlockInfo {
242279 ZAState FixedEntryState{ZAState::ANY};
243280 SmallVector<InstInfo> Insts;
281+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
244282 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
245283 };
246284
@@ -250,18 +288,22 @@ struct MachineSMEABI : public MachineFunctionPass {
250288 SmallVector<ZAState> BundleStates;
251289 std::optional<TPIDR2State> TPIDR2Block;
252290 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
291+ Register AgnosticZABufferPtr = AArch64::NoRegister;
292+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
253293 } State;
254294
255295 MachineFunction *MF = nullptr ;
256296 EdgeBundles *Bundles = nullptr ;
257297 const AArch64Subtarget *Subtarget = nullptr ;
258298 const AArch64RegisterInfo *TRI = nullptr ;
299+ const AArch64FunctionInfo *AFI = nullptr ;
259300 const TargetInstrInfo *TII = nullptr ;
260301 MachineRegisterInfo *MRI = nullptr ;
261302};
262303
263304void MachineSMEABI::collectNeededZAStates (SMEAttrs SMEFnAttrs) {
264- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
305+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
306+ SMEFnAttrs.hasZAState ()) &&
265307 " Expected function to have ZA/ZT0 state!" );
266308
267309 State.Blocks .resize (MF->getNumBlockIDs ());
@@ -295,6 +337,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
295337
296338 Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
297339 auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
340+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
298341 for (MachineInstr &MI : reverse (MBB)) {
299342 MachineBasicBlock::iterator MBBI (MI);
300343 LiveUnits.stepBackward (MI);
@@ -303,8 +346,11 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
303346 // buffer was allocated in SelectionDAG. It marks the end of the
304347 // allocation -- which is a safe point for this pass to insert any TPIDR2
305348 // block setup.
306- if (MI.getOpcode () == AArch64::SMEStateAllocPseudo)
349+ if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
307350 State.AfterSMEProloguePt = MBBI;
351+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
352+ }
353+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
308354 auto [NeededState, InsertPt] = getZAStateBeforeInst (
309355 *TRI, MI, /* ZAOffAtReturn=*/ SMEFnAttrs.hasPrivateZAInterface ());
310356 assert ((InsertPt == MBBI ||
@@ -313,6 +359,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
313359 // TODO: Do something to avoid state changes where NZCV is live.
314360 if (MBBI == FirstTerminatorInsertPt)
315361 Block.PhysLiveRegsAtExit = PhysLiveRegs;
362+ if (MBBI == FirstNonPhiInsertPt)
363+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
316364 if (NeededState != ZAState::ANY)
317365 Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
318366 }
@@ -536,8 +584,6 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
536584void MachineSMEABI::emitAllocateLazySaveBuffer (
537585 MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
538586 MachineFrameInfo &MFI = MF->getFrameInfo ();
539- auto *AFI = MF->getInfo <AArch64FunctionInfo>();
540-
541587 DebugLoc DL = getDebugLoc (MBB, MBBI);
542588 Register SP = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
543589 Register SVL = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
@@ -601,8 +647,7 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
601647 .addImm (AArch64SysReg::TPIDR2_EL0);
602648 // If TPIDR2_EL0 is non-zero, commit the lazy save.
603649 // NOTE: Functions that only use ZT0 don't need to zero ZA.
604- bool ZeroZA =
605- MF->getInfo <AArch64FunctionInfo>()->getSMEFnAttrs ().hasZAState ();
650+ bool ZeroZA = AFI->getSMEFnAttrs ().hasZAState ();
606651 auto CommitZASave =
607652 BuildMI (MBB, MBBI, DL, TII->get (AArch64::CommitZASavePseudo))
608653 .addReg (TPIDR2EL0)
@@ -617,6 +662,86 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
617662 .addImm (1 );
618663}
619664
665+ Register MachineSMEABI::getAgnosticZABufferPtr () {
666+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
667+ return State.AgnosticZABufferPtr ;
668+ Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer ();
669+ State.AgnosticZABufferPtr =
670+ BufferPtr != AArch64::NoRegister
671+ ? BufferPtr
672+ : MF->getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
673+ return State.AgnosticZABufferPtr ;
674+ }
675+
676+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
677+ MachineBasicBlock::iterator MBBI,
678+ LiveRegs PhysLiveRegs, bool IsSave) {
679+ auto *TLI = Subtarget->getTargetLowering ();
680+ DebugLoc DL = getDebugLoc (MBB, MBBI);
681+ Register BufferPtr = AArch64::X0;
682+
683+ PhysRegSave RegSave = createPhysRegSave (PhysLiveRegs, MBB, MBBI, DL);
684+
685+ // Copy the buffer pointer into X0.
686+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferPtr)
687+ .addReg (getAgnosticZABufferPtr ());
688+
689+ // Call __arm_sme_save/__arm_sme_restore.
690+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::BL))
691+ .addReg (BufferPtr, RegState::Implicit)
692+ .addExternalSymbol (TLI->getLibcallName (
693+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
694+ .addRegMask (TRI->getCallPreservedMask (
695+ *MF,
696+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
697+
698+ restorePhyRegSave (RegSave, MBB, MBBI, DL);
699+ }
700+
701+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
702+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
703+ LiveRegs PhysLiveRegs) {
704+ // Buffer already allocated in SelectionDAG.
705+ if (AFI->getEarlyAllocSMESaveBuffer ())
706+ return ;
707+
708+ DebugLoc DL = getDebugLoc (MBB, MBBI);
709+ Register BufferPtr = getAgnosticZABufferPtr ();
710+ Register BufferSize = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
711+
712+ PhysRegSave RegSave = createPhysRegSave (PhysLiveRegs, MBB, MBBI, DL);
713+
714+ // Calculate the SME state size.
715+ {
716+ auto *TLI = Subtarget->getTargetLowering ();
717+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo ();
718+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::BL))
719+ .addExternalSymbol (TLI->getLibcallName (RTLIB::SMEABI_SME_STATE_SIZE))
720+ .addReg (AArch64::X0, RegState::ImplicitDefine)
721+ .addRegMask (TRI->getCallPreservedMask (
722+ *MF, CallingConv::
723+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
724+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferSize)
725+ .addReg (AArch64::X0);
726+ }
727+
728+ // Allocate a buffer object of the size given __arm_sme_state_size.
729+ {
730+ MachineFrameInfo &MFI = MF->getFrameInfo ();
731+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::SUBXrx64), AArch64::SP)
732+ .addReg (AArch64::SP)
733+ .addReg (BufferSize)
734+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
735+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferPtr)
736+ .addReg (AArch64::SP);
737+
738+ // We have just allocated a variable sized object, tell this to PEI.
739+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
740+ }
741+
742+ restorePhyRegSave (RegSave, MBB, MBBI, DL);
743+ }
744+
620745void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
621746 MachineBasicBlock::iterator InsertPt,
622747 ZAState From, ZAState To,
@@ -634,10 +759,7 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
634759 // TODO: Avoid setting up the save buffer if there's no transition to
635760 // LOCAL_SAVED.
636761 if (From == ZAState::CALLER_DORMANT) {
637- assert (MBB.getParent ()
638- ->getInfo <AArch64FunctionInfo>()
639- ->getSMEFnAttrs ()
640- .hasPrivateZAInterface () &&
762+ assert (AFI->getSMEFnAttrs ().hasPrivateZAInterface () &&
641763 " CALLER_DORMANT state requires private ZA interface" );
642764 assert (&MBB == &MBB.getParent ()->front () &&
643765 " CALLER_DORMANT state only valid in entry block" );
@@ -652,12 +774,14 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
652774 }
653775
654776 if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
655- emitSetupLazySave (MBB, InsertPt);
777+ emitZASave (MBB, InsertPt, PhysLiveRegs );
656778 else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
657- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
779+ emitZARestore (MBB, InsertPt, PhysLiveRegs);
658780 else if (To == ZAState::OFF) {
659781 assert (From != ZAState::CALLER_DORMANT &&
660782 " CALLER_DORMANT to OFF should have already been handled" );
783+ assert (!AFI->getSMEFnAttrs ().hasAgnosticZAInterface () &&
784+ " Should not turn ZA off in agnostic ZA function" );
661785 emitZAOff (MBB, InsertPt, /* ClearTPIDR2=*/ From == ZAState::LOCAL_SAVED);
662786 } else {
663787 dbgs () << " Error: Transition from " << getZAStateString (From) << " to "
@@ -675,9 +799,10 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
675799 if (!MF.getSubtarget <AArch64Subtarget>().hasSME ())
676800 return false ;
677801
678- auto * AFI = MF.getInfo <AArch64FunctionInfo>();
802+ AFI = MF.getInfo <AArch64FunctionInfo>();
679803 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
680- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
804+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
805+ !SMEFnAttrs.hasAgnosticZAInterface ())
681806 return false ;
682807
683808 assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -696,15 +821,18 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
696821 insertStateChanges ();
697822
698823 // Allocate save buffer (if needed).
699- if (State.TPIDR2Block ) {
824+ if (State.AgnosticZABufferPtr != AArch64::NoRegister || State. TPIDR2Block ) {
700825 if (State.AfterSMEProloguePt ) {
701826 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
702827 // entry block (due to the probing loop).
703- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
704- *State.AfterSMEProloguePt );
828+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
829+ *State.AfterSMEProloguePt ,
830+ State.PhysLiveRegsAfterSMEPrologue );
705831 } else {
706832 MachineBasicBlock &EntryBlock = MF.front ();
707- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
833+ emitAllocateZASaveBuffer (
834+ EntryBlock, EntryBlock.getFirstNonPHI (),
835+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry );
708836 }
709837 }
710838
0 commit comments