7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This pass implements the SME ABI requirements for ZA state. This includes
10
- // implementing the lazy ZA state save schemes around calls.
10
+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
13
//
@@ -215,9 +215,44 @@ struct MachineSMEABI : public MachineFunctionPass {
215
215
void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
216
216
bool ClearTPIDR2);
217
217
218
+ // Emission routines for agnostic ZA functions.
219
+ void emitSetupFullZASave (MachineBasicBlock &MBB,
220
+ MachineBasicBlock::iterator MBBI,
221
+ LiveRegs PhysLiveRegs);
222
+ // Emit a "full" ZA save or restore. It is "full" in the sense that this
223
+ // function will emit a call to __arm_sme_save or __arm_sme_restore, which
224
+ // handles saving and restoring both ZA and ZT0.
225
+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
226
+ MachineBasicBlock::iterator MBBI,
227
+ LiveRegs PhysLiveRegs, bool IsSave);
228
+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
229
+ MachineBasicBlock::iterator MBBI,
230
+ LiveRegs PhysLiveRegs);
231
+
218
232
void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
219
233
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
220
234
235
+ // Helpers for switching between lazy/full ZA save/restore routines.
236
// Emits a ZA save at MBBI. Functions with an agnostic ZA interface get a
// "full" save via emitFullZASaveRestore(..., IsSave=true); all other functions
// fall back to emitSetupLazySave. PhysLiveRegs is only forwarded on the
// full-save path.
+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
237
+ LiveRegs PhysLiveRegs) {
238
+ if (AFI->getSMEFnAttrs ().hasAgnosticZAInterface ())
239
+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
240
+ return emitSetupLazySave (MBB, MBBI);
241
+ }
242
// Emits a ZA restore at MBBI. Functions with an agnostic ZA interface get a
// "full" restore via emitFullZASaveRestore(..., IsSave=false); all other
// functions fall back to emitRestoreLazySave. PhysLiveRegs is forwarded on
// both paths.
+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
243
+ LiveRegs PhysLiveRegs) {
244
+ if (AFI->getSMEFnAttrs ().hasAgnosticZAInterface ())
245
+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
246
+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
247
+ }
248
// Emits the ZA save-buffer allocation at MBBI. Functions with an agnostic ZA
// interface use emitAllocateFullZASaveBuffer (which receives PhysLiveRegs);
// all other functions use emitAllocateLazySaveBuffer (which does not).
+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
249
+ MachineBasicBlock::iterator MBBI,
250
+ LiveRegs PhysLiveRegs) {
251
+ if (AFI->getSMEFnAttrs ().hasAgnosticZAInterface ())
252
+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
253
+ return emitAllocateLazySaveBuffer (MBB, MBBI);
254
+ }
255
+
221
256
/// Save live physical registers to virtual registers.
222
257
PhysRegSave createPhysRegSave (LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
223
258
MachineBasicBlock::iterator MBBI, DebugLoc DL);
@@ -228,6 +263,8 @@ struct MachineSMEABI : public MachineFunctionPass {
228
263
/// Get or create a TPIDR2 block in this function.
229
264
TPIDR2State getTPIDR2Block ();
230
265
266
+ Register getAgnosticZABufferPtr ();
267
+
231
268
private:
232
269
/// Contains the needed ZA state (and live registers) at an instruction.
233
270
struct InstInfo {
@@ -241,6 +278,7 @@ struct MachineSMEABI : public MachineFunctionPass {
241
278
// Per-basic-block record filled in by collectNeededZAStates.
struct BlockInfo {
242
279
// Defaults to ZAState::ANY.
ZAState FixedEntryState{ZAState::ANY};
243
280
// One entry per instruction whose needed ZA state is not ZAState::ANY.
SmallVector<InstInfo> Insts;
281
// Live physical registers recorded at the block's first non-PHI instruction.
+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
244
282
// Live physical registers recorded at the block's first terminator.
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
245
283
};
246
284
@@ -250,18 +288,22 @@ struct MachineSMEABI : public MachineFunctionPass {
250
288
SmallVector<ZAState> BundleStates;
251
289
std::optional<TPIDR2State> TPIDR2Block;
252
290
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
291
+ Register AgnosticZABufferPtr = AArch64::NoRegister;
292
+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
253
293
} State;
254
294
255
295
MachineFunction *MF = nullptr ;
256
296
EdgeBundles *Bundles = nullptr ;
257
297
const AArch64Subtarget *Subtarget = nullptr ;
258
298
const AArch64RegisterInfo *TRI = nullptr ;
299
+ const AArch64FunctionInfo *AFI = nullptr ;
259
300
const TargetInstrInfo *TII = nullptr ;
260
301
MachineRegisterInfo *MRI = nullptr ;
261
302
};
262
303
263
304
void MachineSMEABI::collectNeededZAStates (SMEAttrs SMEFnAttrs) {
264
- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
305
+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
306
+ SMEFnAttrs.hasZAState ()) &&
265
307
" Expected function to have ZA/ZT0 state!" );
266
308
267
309
State.Blocks .resize (MF->getNumBlockIDs ());
@@ -295,6 +337,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
295
337
296
338
Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
297
339
auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
340
+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
298
341
for (MachineInstr &MI : reverse (MBB)) {
299
342
MachineBasicBlock::iterator MBBI (MI);
300
343
LiveUnits.stepBackward (MI);
@@ -303,8 +346,11 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
303
346
// buffer was allocated in SelectionDAG. It marks the end of the
304
347
// allocation -- which is a safe point for this pass to insert any TPIDR2
305
348
// block setup.
306
- if (MI.getOpcode () == AArch64::SMEStateAllocPseudo)
349
+ if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
307
350
State.AfterSMEProloguePt = MBBI;
351
+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
352
+ }
353
+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
308
354
auto [NeededState, InsertPt] = getZAStateBeforeInst (
309
355
*TRI, MI, /* ZAOffAtReturn=*/ SMEFnAttrs.hasPrivateZAInterface ());
310
356
assert ((InsertPt == MBBI ||
@@ -313,6 +359,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
313
359
// TODO: Do something to avoid state changes where NZCV is live.
314
360
if (MBBI == FirstTerminatorInsertPt)
315
361
Block.PhysLiveRegsAtExit = PhysLiveRegs;
362
+ if (MBBI == FirstNonPhiInsertPt)
363
+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
316
364
if (NeededState != ZAState::ANY)
317
365
Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
318
366
}
@@ -536,8 +584,6 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
536
584
void MachineSMEABI::emitAllocateLazySaveBuffer (
537
585
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
538
586
MachineFrameInfo &MFI = MF->getFrameInfo ();
539
- auto *AFI = MF->getInfo <AArch64FunctionInfo>();
540
-
541
587
DebugLoc DL = getDebugLoc (MBB, MBBI);
542
588
Register SP = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
543
589
Register SVL = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
@@ -601,8 +647,7 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
601
647
.addImm (AArch64SysReg::TPIDR2_EL0);
602
648
// If TPIDR2_EL0 is non-zero, commit the lazy save.
603
649
// NOTE: Functions that only use ZT0 don't need to zero ZA.
604
- bool ZeroZA =
605
- MF->getInfo <AArch64FunctionInfo>()->getSMEFnAttrs ().hasZAState ();
650
+ bool ZeroZA = AFI->getSMEFnAttrs ().hasZAState ();
606
651
auto CommitZASave =
607
652
BuildMI (MBB, MBBI, DL, TII->get (AArch64::CommitZASavePseudo))
608
653
.addReg (TPIDR2EL0)
@@ -617,6 +662,86 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
617
662
.addImm (1 );
618
663
}
619
664
665
// Returns the register used as the agnostic-ZA save buffer pointer, caching it
// in State.AgnosticZABufferPtr. On the first call the register comes from
// AFI->getEarlyAllocSMESaveBuffer() when that is set; otherwise a fresh GPR64
// virtual register is created. Side effect: after the first call
// State.AgnosticZABufferPtr is no longer NoRegister, which
// runOnMachineFunction tests when deciding whether to emit the buffer
// allocation.
+ Register MachineSMEABI::getAgnosticZABufferPtr () {
666
+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
667
+ return State.AgnosticZABufferPtr ;
668
+ Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer ();
669
+ State.AgnosticZABufferPtr =
670
+ BufferPtr != AArch64::NoRegister
671
+ ? BufferPtr
672
+ : MF->getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
673
+ return State.AgnosticZABufferPtr ;
674
+ }
675
+
676
// Emits, before MBBI, a call to the SME save routine (IsSave == true) or the
// SME restore routine (IsSave == false), passing the agnostic-ZA buffer
// pointer in X0. The registers described by PhysLiveRegs are saved before the
// emitted sequence (createPhysRegSave) and restored after it
// (restorePhyRegSave).
+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
677
+ MachineBasicBlock::iterator MBBI,
678
+ LiveRegs PhysLiveRegs, bool IsSave) {
679
+ auto *TLI = Subtarget->getTargetLowering ();
680
+ DebugLoc DL = getDebugLoc (MBB, MBBI);
681
+ Register BufferPtr = AArch64::X0;
682
+
683
+ PhysRegSave RegSave = createPhysRegSave (PhysLiveRegs, MBB, MBBI, DL);
684
+
685
+ // Copy the buffer pointer into X0.
686
+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferPtr)
687
+ .addReg (getAgnosticZABufferPtr ());
688
+
689
+ // Call __arm_sme_save/__arm_sme_restore.
// X0 is added as an implicit use so the call is seen to read the pointer.
690
+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::BL))
691
+ .addReg (BufferPtr, RegState::Implicit)
692
+ .addExternalSymbol (TLI->getLibcallName (
693
+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
694
+ .addRegMask (TRI->getCallPreservedMask (
695
+ *MF,
696
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
697
+
698
+ restorePhyRegSave (RegSave, MBB, MBBI, DL);
699
+ }
700
+
701
// Emits, before MBBI, the stack allocation of the agnostic-ZA save buffer:
// call the SME state-size routine, subtract the returned size from SP, and
// copy the new SP into the buffer-pointer register from
// getAgnosticZABufferPtr(). Emits nothing when
// AFI->getEarlyAllocSMESaveBuffer() is already set. The registers described by
// PhysLiveRegs are saved before the emitted sequence and restored after it.
+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
702
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
703
+ LiveRegs PhysLiveRegs) {
704
+ // Buffer already allocated in SelectionDAG.
705
+ if (AFI->getEarlyAllocSMESaveBuffer ())
706
+ return ;
707
+
708
+ DebugLoc DL = getDebugLoc (MBB, MBBI);
709
+ Register BufferPtr = getAgnosticZABufferPtr ();
710
+ Register BufferSize = MRI->createVirtualRegister (&AArch64::GPR64RegClass);
711
+
712
+ PhysRegSave RegSave = createPhysRegSave (PhysLiveRegs, MBB, MBBI, DL);
713
+
714
+ // Calculate the SME state size.
715
+ {
716
+ auto *TLI = Subtarget->getTargetLowering ();
// NOTE(review): this local shadows the class member TRI, which
// emitFullZASaveRestore uses directly; presumably they are the same
// object and the local is redundant -- confirm.
717
+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo ();
// The call's result is modelled as an implicit def of X0 and then copied
// into the BufferSize virtual register.
718
+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::BL))
719
+ .addExternalSymbol (TLI->getLibcallName (RTLIB::SMEABI_SME_STATE_SIZE))
720
+ .addReg (AArch64::X0, RegState::ImplicitDefine)
721
+ .addRegMask (TRI->getCallPreservedMask (
722
+ *MF, CallingConv::
723
+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
724
+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferSize)
725
+ .addReg (AArch64::X0);
726
+ }
727
+
728
+ // Allocate a buffer object of the size given by __arm_sme_state_size.
729
+ {
730
+ MachineFrameInfo &MFI = MF->getFrameInfo ();
// SP = SP - BufferSize; the decremented SP is then the buffer pointer.
731
+ BuildMI (MBB, MBBI, DL, TII->get (AArch64::SUBXrx64), AArch64::SP)
732
+ .addReg (AArch64::SP)
733
+ .addReg (BufferSize)
734
+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
735
+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::COPY), BufferPtr)
736
+ .addReg (AArch64::SP);
737
+
738
+ // We have just allocated a variable sized object, tell this to PEI.
739
+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
740
+ }
741
+
742
+ restorePhyRegSave (RegSave, MBB, MBBI, DL);
743
+ }
744
+
620
745
void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
621
746
MachineBasicBlock::iterator InsertPt,
622
747
ZAState From, ZAState To,
@@ -634,10 +759,7 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
634
759
// TODO: Avoid setting up the save buffer if there's no transition to
635
760
// LOCAL_SAVED.
636
761
if (From == ZAState::CALLER_DORMANT) {
637
- assert (MBB.getParent ()
638
- ->getInfo <AArch64FunctionInfo>()
639
- ->getSMEFnAttrs ()
640
- .hasPrivateZAInterface () &&
762
+ assert (AFI->getSMEFnAttrs ().hasPrivateZAInterface () &&
641
763
" CALLER_DORMANT state requires private ZA interface" );
642
764
assert (&MBB == &MBB.getParent ()->front () &&
643
765
" CALLER_DORMANT state only valid in entry block" );
@@ -652,12 +774,14 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
652
774
}
653
775
654
776
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
655
- emitSetupLazySave (MBB, InsertPt);
777
+ emitZASave (MBB, InsertPt, PhysLiveRegs );
656
778
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
657
- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
779
+ emitZARestore (MBB, InsertPt, PhysLiveRegs);
658
780
else if (To == ZAState::OFF) {
659
781
assert (From != ZAState::CALLER_DORMANT &&
660
782
" CALLER_DORMANT to OFF should have already been handled" );
783
+ assert (!AFI->getSMEFnAttrs ().hasAgnosticZAInterface () &&
784
+ " Should not turn ZA off in agnostic ZA function" );
661
785
emitZAOff (MBB, InsertPt, /* ClearTPIDR2=*/ From == ZAState::LOCAL_SAVED);
662
786
} else {
663
787
dbgs () << " Error: Transition from " << getZAStateString (From) << " to "
@@ -675,9 +799,10 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
675
799
if (!MF.getSubtarget <AArch64Subtarget>().hasSME ())
676
800
return false ;
677
801
678
- auto * AFI = MF.getInfo <AArch64FunctionInfo>();
802
+ AFI = MF.getInfo <AArch64FunctionInfo>();
679
803
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
680
- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
804
+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
805
+ !SMEFnAttrs.hasAgnosticZAInterface ())
681
806
return false ;
682
807
683
808
assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -696,15 +821,18 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
696
821
insertStateChanges ();
697
822
698
823
// Allocate save buffer (if needed).
699
- if (State.TPIDR2Block ) {
824
+ if (State.AgnosticZABufferPtr != AArch64::NoRegister || State. TPIDR2Block ) {
700
825
if (State.AfterSMEProloguePt ) {
701
826
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
702
827
// entry block (due to the probing loop).
703
- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
704
- *State.AfterSMEProloguePt );
828
+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
829
+ *State.AfterSMEProloguePt ,
830
+ State.PhysLiveRegsAfterSMEPrologue );
705
831
} else {
706
832
MachineBasicBlock &EntryBlock = MF.front ();
707
- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
833
+ emitAllocateZASaveBuffer (
834
+ EntryBlock, EntryBlock.getFirstNonPHI (),
835
+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry );
708
836
}
709
837
}
710
838
0 commit comments