Skip to content

Commit 75b2bf0

Browse files
committed
[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass
This extends the MachineSMEABIPass to handle agnostic ZA functions. This case is currently handled like shared ZA functions, but we don't require ZA state to be reloaded before agnostic ZA calls. Note: This patch does not yet fully handle agnostic ZA functions that can catch exceptions. E.g.: ``` __arm_agnostic("sme_za_state") void try_catch_agnostic_za_callee() { try { agnostic_za_call(); } catch(...) { noexcept_agnostic_za_call(); } } ``` As in this case, we won't commit a ZA save before the `agnostic_za_call()`, which would be needed to restore ZA in the catch block. This will be handled in a later patch. Change-Id: I9cce7b42ec8b64d5442b35231b65dfaf9d149eed
1 parent b18d8f1 commit 75b2bf0

File tree

3 files changed

+322
-38
lines changed

3 files changed

+322
-38
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8290,7 +8290,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82908290
if (Subtarget->hasCustomCallingConv())
82918291
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
82928292

8293-
if (getTM().useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
8293+
if (getTM().useNewSMEABILowering()) {
82948294
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
82958295
SDValue Size;
82968296
if (Attrs.hasZAState()) {
@@ -9111,9 +9111,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91119111
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
91129112
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
91139113
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9114-
// TODO: Handle agnostic ZA functions.
9115-
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9114+
if (!UseNewSMEABILowering)
9115+
return std::nullopt;
9116+
if (IsAgnosticZAFunction) {
9117+
if (CallAttrs.requiresPreservingAllZAState())
9118+
return AArch64ISD::REQUIRES_ZA_SAVE;
91169119
return std::nullopt;
9120+
}
91179121
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
91189122
return std::nullopt;
91199123
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
@@ -9193,7 +9197,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91939197
};
91949198

91959199
bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
9196-
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9200+
bool RequiresSaveAllZA =
9201+
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
91979202
if (RequiresLazySave) {
91989203
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91999204
MachinePointerInfo MPI =

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 155 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This pass implements the SME ABI requirements for ZA state. This includes
10-
// implementing the lazy ZA state save schemes around calls.
10+
// implementing the lazy (and agnostic) ZA state save schemes around calls.
1111
//
1212
//===----------------------------------------------------------------------===//
1313
//
@@ -200,7 +200,7 @@ struct MachineSMEABI : public MachineFunctionPass {
200200

201201
/// Inserts code to handle changes between ZA states within the function.
202202
/// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
203-
void insertStateChanges();
203+
void insertStateChanges(bool IsAgnosticZA);
204204

205205
// Emission routines for private and shared ZA functions (using lazy saves).
206206
void emitNewZAPrologue(MachineBasicBlock &MBB,
@@ -215,8 +215,41 @@ struct MachineSMEABI : public MachineFunctionPass {
215215
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
216216
bool ClearTPIDR2);
217217

218+
// Emission routines for agnostic ZA functions.
219+
void emitSetupFullZASave(MachineBasicBlock &MBB,
220+
MachineBasicBlock::iterator MBBI,
221+
LiveRegs PhysLiveRegs);
222+
void emitFullZASaveRestore(MachineBasicBlock &MBB,
223+
MachineBasicBlock::iterator MBBI,
224+
LiveRegs PhysLiveRegs, bool IsSave);
225+
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
226+
MachineBasicBlock::iterator MBBI,
227+
LiveRegs PhysLiveRegs);
228+
218229
void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
219-
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
230+
ZAState From, ZAState To, LiveRegs PhysLiveRegs,
231+
bool IsAgnosticZA);
232+
233+
// Helpers for switching between lazy/full ZA save/restore routines.
234+
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
235+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
236+
if (IsAgnosticZA)
237+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
238+
return emitSetupLazySave(MBB, MBBI);
239+
}
240+
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
241+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
242+
if (IsAgnosticZA)
243+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
244+
return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
245+
}
246+
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
247+
MachineBasicBlock::iterator MBBI,
248+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
249+
if (IsAgnosticZA)
250+
return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
251+
return emitAllocateLazySaveBuffer(MBB, MBBI);
252+
}
220253

221254
/// Save live physical registers to virtual registers.
222255
PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
@@ -228,6 +261,8 @@ struct MachineSMEABI : public MachineFunctionPass {
228261
/// Get or create a TPIDR2 block in this function.
229262
TPIDR2State getTPIDR2Block();
230263

264+
Register getAgnosticZABufferPtr();
265+
231266
private:
232267
/// Contains the needed ZA state (and live registers) at an instruction.
233268
struct InstInfo {
@@ -241,6 +276,7 @@ struct MachineSMEABI : public MachineFunctionPass {
241276
struct BlockInfo {
242277
ZAState FixedEntryState{ZAState::ANY};
243278
SmallVector<InstInfo> Insts;
279+
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
244280
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
245281
};
246282

@@ -250,6 +286,9 @@ struct MachineSMEABI : public MachineFunctionPass {
250286
SmallVector<ZAState> BundleStates;
251287
std::optional<TPIDR2State> TPIDR2Block;
252288
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
289+
Register AgnosticZABufferPtr = AArch64::NoRegister;
290+
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
291+
bool HasFullZASaveRestore = false;
253292
} State;
254293

255294
MachineFunction *MF = nullptr;
@@ -261,7 +300,8 @@ struct MachineSMEABI : public MachineFunctionPass {
261300
};
262301

263302
void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
264-
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
303+
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
304+
SMEFnAttrs.hasZAState()) &&
265305
"Expected function to have ZA/ZT0 state!");
266306

267307
State.Blocks.resize(MF->getNumBlockIDs());
@@ -295,6 +335,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
295335

296336
Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
297337
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
338+
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
298339
for (MachineInstr &MI : reverse(MBB)) {
299340
MachineBasicBlock::iterator MBBI(MI);
300341
LiveUnits.stepBackward(MI);
@@ -305,7 +346,9 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
305346
// block setup.
306347
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
307348
State.AfterSMEProloguePt = MBBI;
349+
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
308350
}
351+
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
309352
auto [NeededState, InsertPt] = getZAStateBeforeInst(
310353
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
311354
assert((InsertPt == MBBI ||
@@ -314,6 +357,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
314357
// TODO: Do something to avoid state changes where NZCV is live.
315358
if (MBBI == FirstTerminatorInsertPt)
316359
Block.PhysLiveRegsAtExit = PhysLiveRegs;
360+
if (MBBI == FirstNonPhiInsertPt)
361+
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
317362
if (NeededState != ZAState::ANY)
318363
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
319364
}
@@ -380,7 +425,7 @@ void MachineSMEABI::assignBundleZAStates() {
380425
}
381426
}
382427

383-
void MachineSMEABI::insertStateChanges() {
428+
void MachineSMEABI::insertStateChanges(bool IsAgnosticZA) {
384429
for (MachineBasicBlock &MBB : *MF) {
385430
const BlockInfo &Block = State.Blocks[MBB.getNumber()];
386431
ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(),
@@ -393,7 +438,7 @@ void MachineSMEABI::insertStateChanges() {
393438
for (auto &Inst : Block.Insts) {
394439
if (CurrentState != Inst.NeededState)
395440
emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
396-
Inst.PhysLiveRegs);
441+
Inst.PhysLiveRegs, IsAgnosticZA);
397442
CurrentState = Inst.NeededState;
398443
}
399444

@@ -404,7 +449,7 @@ void MachineSMEABI::insertStateChanges() {
404449
State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
405450
if (CurrentState != OutState)
406451
emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
407-
Block.PhysLiveRegsAtExit);
452+
Block.PhysLiveRegsAtExit, IsAgnosticZA);
408453
}
409454
}
410455

@@ -618,10 +663,95 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
618663
.addImm(1);
619664
}
620665

666+
Register MachineSMEABI::getAgnosticZABufferPtr() {
667+
if (State.AgnosticZABufferPtr != AArch64::NoRegister)
668+
return State.AgnosticZABufferPtr;
669+
if (auto BufferPtr =
670+
MF->getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
671+
BufferPtr != AArch64::NoRegister)
672+
State.AgnosticZABufferPtr = BufferPtr;
673+
else
674+
State.AgnosticZABufferPtr =
675+
MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
676+
return State.AgnosticZABufferPtr;
677+
}
678+
679+
void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
680+
MachineBasicBlock::iterator MBBI,
681+
LiveRegs PhysLiveRegs, bool IsSave) {
682+
auto *TLI = Subtarget->getTargetLowering();
683+
State.HasFullZASaveRestore = true;
684+
DebugLoc DL = getDebugLoc(MBB, MBBI);
685+
Register BufferPtr = AArch64::X0;
686+
687+
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
688+
689+
// Copy the buffer pointer into X0.
690+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
691+
.addReg(getAgnosticZABufferPtr());
692+
693+
// Call __arm_sme_save/__arm_sme_restore.
694+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
695+
.addReg(BufferPtr, RegState::Implicit)
696+
.addExternalSymbol(TLI->getLibcallName(
697+
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
698+
.addRegMask(TRI->getCallPreservedMask(
699+
*MF,
700+
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
701+
702+
restorePhyRegSave(RegSave, MBB, MBBI, DL);
703+
}
704+
705+
void MachineSMEABI::emitAllocateFullZASaveBuffer(
706+
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
707+
LiveRegs PhysLiveRegs) {
708+
auto *AFI = MF->getInfo<AArch64FunctionInfo>();
709+
710+
// Buffer already allocated in SelectionDAG.
711+
if (AFI->getEarlyAllocSMESaveBuffer())
712+
return;
713+
714+
DebugLoc DL = getDebugLoc(MBB, MBBI);
715+
Register BufferPtr = getAgnosticZABufferPtr();
716+
Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
717+
718+
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
719+
720+
// Calculate the SME state size.
721+
{
722+
auto *TLI = Subtarget->getTargetLowering();
723+
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
724+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
725+
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
726+
.addReg(AArch64::X0, RegState::ImplicitDefine)
727+
.addRegMask(TRI->getCallPreservedMask(
728+
*MF, CallingConv::
729+
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
730+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
731+
.addReg(AArch64::X0);
732+
}
733+
734+
// Allocate a buffer object of the size given __arm_sme_state_size.
735+
{
736+
MachineFrameInfo &MFI = MF->getFrameInfo();
737+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
738+
.addReg(AArch64::SP)
739+
.addReg(BufferSize)
740+
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
741+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
742+
.addReg(AArch64::SP);
743+
744+
// We have just allocated a variable sized object, tell this to PEI.
745+
MFI.CreateVariableSizedObject(Align(16), nullptr);
746+
}
747+
748+
restorePhyRegSave(RegSave, MBB, MBBI, DL);
749+
}
750+
621751
void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
622752
MachineBasicBlock::iterator InsertPt,
623753
ZAState From, ZAState To,
624-
LiveRegs PhysLiveRegs) {
754+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
625755

626756
// ZA not used.
627757
if (From == ZAState::ANY || To == ZAState::ANY)
@@ -653,12 +783,13 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
653783
}
654784

655785
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
656-
emitSetupLazySave(MBB, InsertPt);
786+
emitZASave(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
657787
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
658-
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
788+
emitZARestore(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
659789
else if (To == ZAState::OFF) {
660790
assert(From != ZAState::CALLER_DORMANT &&
661791
"CALLER_DORMANT to OFF should have already been handled");
792+
assert(!IsAgnosticZA && "Should not turn ZA off in agnostic ZA function");
662793
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
663794
} else {
664795
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
@@ -678,7 +809,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
678809

679810
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
680811
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
681-
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
812+
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
813+
!SMEFnAttrs.hasAgnosticZAInterface())
682814
return false;
683815

684816
assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -692,20 +824,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
692824
TRI = Subtarget->getRegisterInfo();
693825
MRI = &MF.getRegInfo();
694826

827+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
828+
695829
collectNeededZAStates(SMEFnAttrs);
696830
assignBundleZAStates();
697-
insertStateChanges();
831+
insertStateChanges(/*IsAgnosticZA=*/IsAgnosticZA);
698832

699833
// Allocate save buffer (if needed).
700-
if (State.TPIDR2Block) {
834+
if (State.HasFullZASaveRestore || State.TPIDR2Block) {
701835
if (State.AfterSMEProloguePt) {
702836
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
703837
// entry block (due to the probing loop).
704-
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
705-
*State.AfterSMEProloguePt);
838+
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
839+
*State.AfterSMEProloguePt,
840+
State.PhysLiveRegsAfterSMEPrologue,
841+
/*IsAgnosticZA=*/IsAgnosticZA);
706842
} else {
707843
MachineBasicBlock &EntryBlock = MF.front();
708-
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
844+
emitAllocateZASaveBuffer(
845+
EntryBlock, EntryBlock.getFirstNonPHI(),
846+
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry,
847+
/*IsAgnosticZA=*/IsAgnosticZA);
709848
}
710849
}
711850

0 commit comments

Comments
 (0)