77// ===----------------------------------------------------------------------===//
88//
99// This pass implements the SME ABI requirements for ZA state. This includes
10- // implementing the lazy ZA state save schemes around calls.
10+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
1111//
1212// ===----------------------------------------------------------------------===//
1313
@@ -128,7 +128,7 @@ struct MachineSMEABI : public MachineFunctionPass {
128128
129129 void collectNeededZAStates (MachineFunction &MF, SMEAttrs);
130130 void pickBundleZAStates (MachineFunction &MF);
131- void insertStateChanges (MachineFunction &MF);
131+ void insertStateChanges (MachineFunction &MF, bool IsAgnosticZA );
132132
133133 // Emission routines for private and shared ZA functions (using lazy saves).
134134 void emitNewZAPrologue (MachineBasicBlock &MBB,
@@ -143,11 +143,46 @@ struct MachineSMEABI : public MachineFunctionPass {
143143 void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
144144 bool ClearTPIDR2);
145145
146+ // Emission routines for agnostic ZA functions.
147+ void emitSetupFullZASave (MachineBasicBlock &MBB,
148+ MachineBasicBlock::iterator MBBI,
149+ LiveRegs PhysLiveRegs);
150+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
151+ MachineBasicBlock::iterator MBBI,
152+ LiveRegs PhysLiveRegs, bool IsSave);
153+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
154+ MachineBasicBlock::iterator MBBI,
155+ LiveRegs PhysLiveRegs);
156+
146157 void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
147- ZAState From, ZAState To, LiveRegs PhysLiveRegs);
158+ ZAState From, ZAState To, LiveRegs PhysLiveRegs,
159+ bool IsAgnosticZA);
160+
161+ // Helpers for switching between lazy/full ZA save/restore routines.
// Emits the ZA save sequence before MBBI, picking the scheme from IsAgnosticZA:
//  - agnostic ZA: emitFullZASaveRestore() with IsSave=true;
//  - otherwise:   emitSetupLazySave().
// PhysLiveRegs is only forwarded on the agnostic path.
162+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
163+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
164+ if (IsAgnosticZA)
165+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
166+ return emitSetupLazySave (MBB, MBBI);
167+ }
// Emits the ZA restore sequence before MBBI, picking the scheme from
// IsAgnosticZA:
//  - agnostic ZA: emitFullZASaveRestore() with IsSave=false;
//  - otherwise:   emitRestoreLazySave().
// Both paths receive PhysLiveRegs.
168+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
169+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
170+ if (IsAgnosticZA)
171+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
172+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
173+ }
// Emits the allocation of the ZA save buffer before MBBI, picking the scheme
// from IsAgnosticZA:
//  - agnostic ZA: emitAllocateFullZASaveBuffer();
//  - otherwise:   emitAllocateLazySaveBuffer().
// PhysLiveRegs is only forwarded on the agnostic path.
174+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
175+ MachineBasicBlock::iterator MBBI,
176+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
177+ if (IsAgnosticZA)
178+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
179+ return emitAllocateLazySaveBuffer (MBB, MBBI);
180+ }
148181
149182 TPIDR2State getTPIDR2Block (MachineFunction &MF);
150183
184+ Register getAgnosticZABufferPtr (MachineFunction &MF);
185+
151186private:
152187 struct InstInfo {
153188 ZAState NeededState{ZAState::ANY};
@@ -158,6 +193,7 @@ struct MachineSMEABI : public MachineFunctionPass {
// Per-basic-block information gathered by collectNeededZAStates(); entries
// live in State.Blocks and are looked up by MBB.getNumber().
158193 struct BlockInfo {
// NOTE(review): presumably a ZA state the block must be entered in (ANY when
// unconstrained) -- its uses are not visible in this excerpt; confirm.
159194 ZAState FixedEntryState{ZAState::ANY};
// One entry per instruction whose needed ZA state is not ZAState::ANY.
160195 SmallVector<InstInfo> Insts;
// Live physical registers recorded at the block's first non-PHI instruction.
// Used when the save buffer allocation is emitted at the entry block's start.
196+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
// Live physical registers at the end of the block, overwritten with the value
// recorded at the first terminator when that instruction is visited. Used for
// the state change emitted before the first terminator.
161197 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
162198 };
163199
@@ -167,6 +203,9 @@ struct MachineSMEABI : public MachineFunctionPass {
167203 SmallVector<ZAState> BundleStates;
168204 std::optional<TPIDR2State> TPIDR2Block;
169205 std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
206+ Register AgnosticZABufferPtr = AArch64::NoRegister;
207+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
208+ bool HasFullZASaveRestore = false ;
170209 } State;
171210
172211 EdgeBundles *Bundles = nullptr ;
@@ -175,7 +214,8 @@ struct MachineSMEABI : public MachineFunctionPass {
175214void MachineSMEABI::collectNeededZAStates (MachineFunction &MF,
176215 SMEAttrs SMEFnAttrs) {
177216 const TargetRegisterInfo &TRI = *MF.getSubtarget ().getRegisterInfo ();
178- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
217+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
218+ SMEFnAttrs.hasZAState ()) &&
179219 " Expected function to have ZA/ZT0 state!" );
180220
181221 State.Blocks .resize (MF.getNumBlockIDs ());
@@ -209,6 +249,7 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
209249
210250 Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
211251 auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
252+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
212253 for (MachineInstr &MI : reverse (MBB)) {
213254 MachineBasicBlock::iterator MBBI (MI);
214255 LiveUnits.stepBackward (MI);
@@ -219,15 +260,20 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
219260 // block setup.
220261 if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
221262 State.AfterSMEProloguePt = MBBI;
263+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
222264 }
265+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
223266 auto [NeededState, InsertPt] = getInstNeededZAState (
224- TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface ());
267+ TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface () ||
268+ SMEFnAttrs.hasAgnosticZAInterface ());
225269 assert ((InsertPt == MBBI ||
226270 InsertPt->getOpcode () == AArch64::ADJCALLSTACKDOWN) &&
227271 " Unexpected state change insertion point!" );
228272 // TODO: Do something to avoid state changes where NZCV is live.
229273 if (MBBI == FirstTerminatorInsertPt)
230274 Block.PhysLiveRegsAtExit = PhysLiveRegs;
275+ if (MBBI == FirstNonPhiInsertPt)
276+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
231277 if (NeededState != ZAState::ANY)
232278 Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
233279 }
@@ -294,7 +340,7 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
294340 }
295341}
296342
297- void MachineSMEABI::insertStateChanges (MachineFunction &MF) {
343+ void MachineSMEABI::insertStateChanges (MachineFunction &MF, bool IsAgnosticZA ) {
298344 for (MachineBasicBlock &MBB : MF) {
299345 BlockInfo &Block = State.Blocks [MBB.getNumber ()];
300346 ZAState InState =
@@ -309,7 +355,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
309355 for (auto &Inst : Block.Insts ) {
310356 if (CurrentState != Inst.NeededState )
311357 emitStateChange (MBB, Inst.InsertPt , CurrentState, Inst.NeededState ,
312- Inst.PhysLiveRegs );
358+ Inst.PhysLiveRegs , IsAgnosticZA );
313359 CurrentState = Inst.NeededState ;
314360 }
315361
@@ -318,7 +364,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
318364
319365 if (CurrentState != OutState)
320366 emitStateChange (MBB, MBB.getFirstTerminator (), CurrentState, OutState,
321- Block.PhysLiveRegsAtExit );
367+ Block.PhysLiveRegsAtExit , IsAgnosticZA );
322368 }
323369}
324370
@@ -571,10 +617,98 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
571617 emitZeroZA (TII, DL, MBB, MBBI, /* Mask=*/ 0b11111111 );
572618}
573619
// Returns the register holding the pointer to the agnostic-ZA save buffer.
// The result is cached in State.AgnosticZABufferPtr, so every call within one
// function returns the same register.
620+ Register MachineSMEABI::getAgnosticZABufferPtr (MachineFunction &MF) {
// Already chosen by an earlier call.
621+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
622+ return State.AgnosticZABufferPtr ;
// Prefer the register recorded in AArch64FunctionInfo as the early-allocated
// SME save buffer, if one exists.
623+ if (auto BufferPtr =
624+ MF.getInfo <AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer ();
625+ BufferPtr != AArch64::NoRegister)
626+ State.AgnosticZABufferPtr = BufferPtr;
627+ else
// Otherwise create a fresh GPR64 virtual register for the pointer.
628+ State.AgnosticZABufferPtr =
629+ MF.getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
630+ return State.AgnosticZABufferPtr ;
631+ }
632+
// Emits, before MBBI, a call to __arm_sme_save (IsSave == true) or
// __arm_sme_restore (IsSave == false), passing the agnostic-ZA buffer pointer
// in X0. Also sets State.HasFullZASaveRestore, which runOnMachineFunction()
// checks to decide whether a save buffer must be allocated.
//
// PhysLiveRegs is handed to ScopedPhysRegSave.
// NOTE(review): presumably it preserves those live physical registers around
// the emitted sequence -- its definition is not visible here; confirm.
633+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
634+ MachineBasicBlock::iterator MBBI,
635+ LiveRegs PhysLiveRegs, bool IsSave) {
636+ MachineFunction &MF = *MBB.getParent ();
637+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
638+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo ();
639+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
640+ MachineRegisterInfo &MRI = MF.getRegInfo ();
641+
// Record that a full save/restore was emitted so the buffer gets allocated.
642+ State.HasFullZASaveRestore = true ;
643+ DebugLoc DL = getDebugLoc (MBB, MBBI);
// The support routine takes the buffer pointer in X0.
644+ Register BufferPtr = AArch64::X0;
645+
646+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
647+
648+ // Copy the buffer pointer into X0.
649+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
650+ .addReg (getAgnosticZABufferPtr (MF));
651+
652+ // Call __arm_sme_save/__arm_sme_restore. X0 is added as an implicit use and
// the call-preserved mask is the one for the SME ABI support-routine calling
// convention (PreserveMost_From_X1).
653+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
654+ .addReg (BufferPtr, RegState::Implicit)
655+ .addExternalSymbol (IsSave ? " __arm_sme_save" : " __arm_sme_restore" )
656+ .addRegMask (TRI.getCallPreservedMask (
657+ MF,
658+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
659+ }
660+
// Emits, before MBBI, the stack allocation of the buffer used by the full
// (agnostic) ZA save/restore routines:
//   1. call __arm_sme_state_size and read the size from X0;
//   2. subtract that size from SP and copy the new SP into the buffer-pointer
//      register returned by getAgnosticZABufferPtr();
//   3. register a variable-sized object with the frame info.
// Does nothing if AArch64FunctionInfo already reports an early-allocated
// buffer.
//
// PhysLiveRegs is handed to ScopedPhysRegSave.
// NOTE(review): presumably it preserves those live physical registers around
// the emitted sequence -- its definition is not visible here; confirm.
661+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
662+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
663+ LiveRegs PhysLiveRegs) {
664+ MachineFunction &MF = *MBB.getParent ();
665+ MachineFrameInfo &MFI = MF.getFrameInfo ();
666+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
667+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
668+ MachineRegisterInfo &MRI = MF.getRegInfo ();
669+ auto *AFI = MF.getInfo <AArch64FunctionInfo>();
670+
671+ // Buffer already allocated in SelectionDAG.
672+ if (AFI->getEarlyAllocSMESaveBuffer ())
673+ return ;
674+
675+ DebugLoc DL = getDebugLoc (MBB, MBBI);
676+ Register BufferPtr = getAgnosticZABufferPtr (MF);
677+ Register BufferSize = MRI.createVirtualRegister (&AArch64::GPR64RegClass);
678+
679+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
680+
681+ // Calculate the SME state size. The call defines X0 (added as an implicit
// def), which is then copied into the BufferSize virtual register.
682+ {
683+ const AArch64RegisterInfo *TRI = Subtarget.getRegisterInfo ();
684+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
685+ .addExternalSymbol (" __arm_sme_state_size" )
686+ .addReg (AArch64::X0, RegState::ImplicitDefine)
687+ .addRegMask (TRI->getCallPreservedMask (
688+ MF, CallingConv::
689+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
690+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferSize)
691+ .addReg (AArch64::X0);
692+ }
693+
694+ // Allocate a buffer object of the size given by __arm_sme_state_size.
695+ {
// SP = SP - BufferSize; the new SP is the address of the buffer.
696+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::SUBXrx64), AArch64::SP)
697+ .addReg (AArch64::SP)
698+ .addReg (BufferSize)
699+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
700+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
701+ .addReg (AArch64::SP);
702+
703+ // We have just allocated a variable sized object, tell this to PEI.
704+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
705+ }
706+ }
707+
574708void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
575709 MachineBasicBlock::iterator InsertPt,
576710 ZAState From, ZAState To,
577- LiveRegs PhysLiveRegs) {
711+ LiveRegs PhysLiveRegs, bool IsAgnosticZA ) {
578712
579713 // ZA not used.
580714 if (From == ZAState::ANY || To == ZAState::ANY)
@@ -601,10 +735,11 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
601735 }
602736
603737 if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
604- emitSetupLazySave (MBB, InsertPt);
738+ emitZASave (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
605739 else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
606- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
740+ emitZARestore (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
607741 else if (To == ZAState::OFF) {
742+ assert (!IsAgnosticZA && " Should not turn ZA off in agnostic ZA function" );
608743 // If we're exiting from the CALLER_DORMANT state that means this new ZA
609744 // function did not touch ZA (so ZA was never turned on).
610745 if (From != ZAState::CALLER_DORMANT)
@@ -627,7 +762,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
627762
628763 auto *AFI = MF.getInfo <AArch64FunctionInfo>();
629764 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
630- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
765+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
766+ !SMEFnAttrs.hasAgnosticZAInterface ())
631767 return false ;
632768
633769 assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -636,20 +772,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
636772 State = PassState{};
637773 Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles ();
638774
775+ bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface ();
776+
639777 collectNeededZAStates (MF, SMEFnAttrs);
640778 pickBundleZAStates (MF);
641- insertStateChanges (MF);
779+ insertStateChanges (MF, /* IsAgnosticZA= */ IsAgnosticZA );
642780
643781 // Allocate save buffer (if needed).
644- if (State.TPIDR2Block .has_value ()) {
782+ if (State.HasFullZASaveRestore || State. TPIDR2Block .has_value ()) {
645783 if (State.AfterSMEProloguePt ) {
646784 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
647785 // entry block (due to the probing loop).
648- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
649- *State.AfterSMEProloguePt );
786+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
787+ *State.AfterSMEProloguePt ,
788+ State.PhysLiveRegsAfterSMEPrologue ,
789+ /* IsAgnosticZA=*/ IsAgnosticZA);
650790 } else {
651791 MachineBasicBlock &EntryBlock = MF.front ();
652- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
792+ emitAllocateZASaveBuffer (
793+ EntryBlock, EntryBlock.getFirstNonPHI (),
794+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry ,
795+ /* IsAgnosticZA=*/ IsAgnosticZA);
653796 }
654797 }
655798
0 commit comments