Skip to content

Commit 5a62e13

Browse files
committed
[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass
This extends the MachineSMEABIPass to handle agnostic ZA functions. Agnostic ZA functions are currently handled like shared ZA functions, except that ZA state is not required to be reloaded before agnostic ZA calls.

Note: This patch does not yet fully handle agnostic ZA functions that can catch exceptions. E.g.:

```
__arm_agnostic("sme_za_state") void try_catch_agnostic_za_callee() {
  try {
    agnostic_za_call();
  } catch(...) {
    noexcept_agnostic_za_call();
  }
}
```

In this case, we won't commit a ZA save before the `agnostic_za_call()`, which would be needed to restore ZA in the catch block. This will be handled in a later patch.

Change-Id: I9cce7b42ec8b64d5442b35231b65dfaf9d149eed
1 parent 94fd328 commit 5a62e13

File tree

3 files changed

+324
-39
lines changed

3 files changed

+324
-39
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8291,7 +8291,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82918291
if (Subtarget->hasCustomCallingConv())
82928292
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
82938293

8294-
if (getTM().useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
8294+
if (getTM().useNewSMEABILowering()) {
82958295
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
82968296
SDValue Size;
82978297
if (Attrs.hasZAState()) {
@@ -9106,9 +9106,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91069106
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
91079107
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
91089108
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9109-
// TODO: Handle agnostic ZA functions.
9110-
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9109+
if (!UseNewSMEABILowering)
9110+
return std::nullopt;
9111+
if (IsAgnosticZAFunction) {
9112+
if (CallAttrs.requiresPreservingAllZAState())
9113+
return AArch64ISD::REQUIRES_ZA_SAVE;
91119114
return std::nullopt;
9115+
}
91129116
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
91139117
return std::nullopt;
91149118
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
@@ -9188,7 +9192,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91889192
};
91899193

91909194
bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
9191-
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9195+
bool RequiresSaveAllZA =
9196+
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
91929197
if (RequiresLazySave) {
91939198
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91949199
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 157 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This pass implements the SME ABI requirements for ZA state. This includes
10-
// implementing the lazy ZA state save schemes around calls.
10+
// implementing the lazy (and agnostic) ZA state save schemes around calls.
1111
//
1212
//===----------------------------------------------------------------------===//
1313
//
@@ -200,7 +200,7 @@ struct MachineSMEABI : public MachineFunctionPass {
200200

201201
/// Inserts code to handle changes between ZA states within the function.
202202
/// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
203-
void insertStateChanges();
203+
void insertStateChanges(bool IsAgnosticZA);
204204

205205
// Emission routines for private and shared ZA functions (using lazy saves).
206206
void emitNewZAPrologue(MachineBasicBlock &MBB,
@@ -215,8 +215,41 @@ struct MachineSMEABI : public MachineFunctionPass {
215215
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
216216
bool ClearTPIDR2);
217217

218+
// Emission routines for agnostic ZA functions.
219+
void emitSetupFullZASave(MachineBasicBlock &MBB,
220+
MachineBasicBlock::iterator MBBI,
221+
LiveRegs PhysLiveRegs);
222+
void emitFullZASaveRestore(MachineBasicBlock &MBB,
223+
MachineBasicBlock::iterator MBBI,
224+
LiveRegs PhysLiveRegs, bool IsSave);
225+
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
226+
MachineBasicBlock::iterator MBBI,
227+
LiveRegs PhysLiveRegs);
228+
218229
void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
219-
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
230+
ZAState From, ZAState To, LiveRegs PhysLiveRegs,
231+
bool IsAgnosticZA);
232+
233+
// Helpers for switching between lazy/full ZA save/restore routines.
234+
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
235+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
236+
if (IsAgnosticZA)
237+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
238+
return emitSetupLazySave(MBB, MBBI);
239+
}
240+
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
241+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
242+
if (IsAgnosticZA)
243+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
244+
return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
245+
}
246+
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
247+
MachineBasicBlock::iterator MBBI,
248+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
249+
if (IsAgnosticZA)
250+
return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
251+
return emitAllocateLazySaveBuffer(MBB, MBBI);
252+
}
220253

221254
/// Save live physical registers to virtual registers.
222255
PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
@@ -228,6 +261,8 @@ struct MachineSMEABI : public MachineFunctionPass {
228261
/// Get or create a TPIDR2 block in this function.
229262
TPIDR2State getTPIDR2Block();
230263

264+
Register getAgnosticZABufferPtr();
265+
231266
private:
232267
/// Contains the needed ZA state (and live registers) at an instruction.
233268
struct InstInfo {
@@ -241,6 +276,7 @@ struct MachineSMEABI : public MachineFunctionPass {
241276
struct BlockInfo {
242277
ZAState FixedEntryState{ZAState::ANY};
243278
SmallVector<InstInfo> Insts;
279+
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
244280
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
245281
};
246282

@@ -250,6 +286,9 @@ struct MachineSMEABI : public MachineFunctionPass {
250286
SmallVector<ZAState> BundleStates;
251287
std::optional<TPIDR2State> TPIDR2Block;
252288
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
289+
Register AgnosticZABufferPtr = AArch64::NoRegister;
290+
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
291+
bool HasFullZASaveRestore = false;
253292
} State;
254293

255294
MachineFunction *MF = nullptr;
@@ -261,7 +300,8 @@ struct MachineSMEABI : public MachineFunctionPass {
261300
};
262301

263302
void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
264-
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
303+
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
304+
SMEFnAttrs.hasZAState()) &&
265305
"Expected function to have ZA/ZT0 state!");
266306

267307
State.Blocks.resize(MF->getNumBlockIDs());
@@ -295,6 +335,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
295335

296336
Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
297337
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
338+
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
298339
for (MachineInstr &MI : reverse(MBB)) {
299340
MachineBasicBlock::iterator MBBI(MI);
300341
LiveUnits.stepBackward(MI);
@@ -303,8 +344,11 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
303344
// buffer was allocated in SelectionDAG. It marks the end of the
304345
// allocation -- which is a safe point for this pass to insert any TPIDR2
305346
// block setup.
306-
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo)
347+
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
307348
State.AfterSMEProloguePt = MBBI;
349+
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
350+
}
351+
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
308352
auto [NeededState, InsertPt] = getZAStateBeforeInst(
309353
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
310354
assert((InsertPt == MBBI ||
@@ -313,6 +357,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
313357
// TODO: Do something to avoid state changes where NZCV is live.
314358
if (MBBI == FirstTerminatorInsertPt)
315359
Block.PhysLiveRegsAtExit = PhysLiveRegs;
360+
if (MBBI == FirstNonPhiInsertPt)
361+
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
316362
if (NeededState != ZAState::ANY)
317363
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
318364
}
@@ -379,7 +425,7 @@ void MachineSMEABI::assignBundleZAStates() {
379425
}
380426
}
381427

382-
void MachineSMEABI::insertStateChanges() {
428+
void MachineSMEABI::insertStateChanges(bool IsAgnosticZA) {
383429
for (MachineBasicBlock &MBB : *MF) {
384430
const BlockInfo &Block = State.Blocks[MBB.getNumber()];
385431
ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(),
@@ -392,7 +438,7 @@ void MachineSMEABI::insertStateChanges() {
392438
for (auto &Inst : Block.Insts) {
393439
if (CurrentState != Inst.NeededState)
394440
emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
395-
Inst.PhysLiveRegs);
441+
Inst.PhysLiveRegs, IsAgnosticZA);
396442
CurrentState = Inst.NeededState;
397443
}
398444

@@ -403,7 +449,7 @@ void MachineSMEABI::insertStateChanges() {
403449
State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
404450
if (CurrentState != OutState)
405451
emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
406-
Block.PhysLiveRegsAtExit);
452+
Block.PhysLiveRegsAtExit, IsAgnosticZA);
407453
}
408454
}
409455

@@ -617,10 +663,95 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
617663
.addImm(1);
618664
}
619665

666+
Register MachineSMEABI::getAgnosticZABufferPtr() {
667+
if (State.AgnosticZABufferPtr != AArch64::NoRegister)
668+
return State.AgnosticZABufferPtr;
669+
if (auto BufferPtr =
670+
MF->getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
671+
BufferPtr != AArch64::NoRegister)
672+
State.AgnosticZABufferPtr = BufferPtr;
673+
else
674+
State.AgnosticZABufferPtr =
675+
MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
676+
return State.AgnosticZABufferPtr;
677+
}
678+
679+
void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
680+
MachineBasicBlock::iterator MBBI,
681+
LiveRegs PhysLiveRegs, bool IsSave) {
682+
auto *TLI = Subtarget->getTargetLowering();
683+
State.HasFullZASaveRestore = true;
684+
DebugLoc DL = getDebugLoc(MBB, MBBI);
685+
Register BufferPtr = AArch64::X0;
686+
687+
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
688+
689+
// Copy the buffer pointer into X0.
690+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
691+
.addReg(getAgnosticZABufferPtr());
692+
693+
// Call __arm_sme_save/__arm_sme_restore.
694+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
695+
.addReg(BufferPtr, RegState::Implicit)
696+
.addExternalSymbol(TLI->getLibcallName(
697+
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
698+
.addRegMask(TRI->getCallPreservedMask(
699+
*MF,
700+
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
701+
702+
restorePhyRegSave(RegSave, MBB, MBBI, DL);
703+
}
704+
705+
void MachineSMEABI::emitAllocateFullZASaveBuffer(
706+
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
707+
LiveRegs PhysLiveRegs) {
708+
auto *AFI = MF->getInfo<AArch64FunctionInfo>();
709+
710+
// Buffer already allocated in SelectionDAG.
711+
if (AFI->getEarlyAllocSMESaveBuffer())
712+
return;
713+
714+
DebugLoc DL = getDebugLoc(MBB, MBBI);
715+
Register BufferPtr = getAgnosticZABufferPtr();
716+
Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
717+
718+
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
719+
720+
// Calculate the SME state size.
721+
{
722+
auto *TLI = Subtarget->getTargetLowering();
723+
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
724+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
725+
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
726+
.addReg(AArch64::X0, RegState::ImplicitDefine)
727+
.addRegMask(TRI->getCallPreservedMask(
728+
*MF, CallingConv::
729+
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
730+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
731+
.addReg(AArch64::X0);
732+
}
733+
734+
// Allocate a buffer object of the size given by __arm_sme_state_size.
735+
{
736+
MachineFrameInfo &MFI = MF->getFrameInfo();
737+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
738+
.addReg(AArch64::SP)
739+
.addReg(BufferSize)
740+
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
741+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
742+
.addReg(AArch64::SP);
743+
744+
// We have just allocated a variable-sized object; tell PEI about it.
745+
MFI.CreateVariableSizedObject(Align(16), nullptr);
746+
}
747+
748+
restorePhyRegSave(RegSave, MBB, MBBI, DL);
749+
}
750+
620751
void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
621752
MachineBasicBlock::iterator InsertPt,
622753
ZAState From, ZAState To,
623-
LiveRegs PhysLiveRegs) {
754+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
624755

625756
// ZA not used.
626757
if (From == ZAState::ANY || To == ZAState::ANY)
@@ -652,12 +783,13 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
652783
}
653784

654785
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
655-
emitSetupLazySave(MBB, InsertPt);
786+
emitZASave(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
656787
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
657-
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
788+
emitZARestore(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
658789
else if (To == ZAState::OFF) {
659790
assert(From != ZAState::CALLER_DORMANT &&
660791
"CALLER_DORMANT to OFF should have already been handled");
792+
assert(!IsAgnosticZA && "Should not turn ZA off in agnostic ZA function");
661793
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
662794
} else {
663795
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
@@ -677,7 +809,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
677809

678810
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
679811
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
680-
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
812+
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
813+
!SMEFnAttrs.hasAgnosticZAInterface())
681814
return false;
682815

683816
assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -691,20 +824,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
691824
TRI = Subtarget->getRegisterInfo();
692825
MRI = &MF.getRegInfo();
693826

827+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
828+
694829
collectNeededZAStates(SMEFnAttrs);
695830
assignBundleZAStates();
696-
insertStateChanges();
831+
insertStateChanges(/*IsAgnosticZA=*/IsAgnosticZA);
697832

698833
// Allocate save buffer (if needed).
699-
if (State.TPIDR2Block) {
834+
if (State.HasFullZASaveRestore || State.TPIDR2Block) {
700835
if (State.AfterSMEProloguePt) {
701836
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
702837
// entry block (due to the probing loop).
703-
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
704-
*State.AfterSMEProloguePt);
838+
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
839+
*State.AfterSMEProloguePt,
840+
State.PhysLiveRegsAfterSMEPrologue,
841+
/*IsAgnosticZA=*/IsAgnosticZA);
705842
} else {
706843
MachineBasicBlock &EntryBlock = MF.front();
707-
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
844+
emitAllocateZASaveBuffer(
845+
EntryBlock, EntryBlock.getFirstNonPHI(),
846+
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry,
847+
/*IsAgnosticZA=*/IsAgnosticZA);
708848
}
709849
}
710850

0 commit comments

Comments
 (0)