Skip to content

Commit 96d5567

Browse files
authored
[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass (#149064)
This extends the MachineSMEABIPass to handle agnostic ZA functions. This case is currently handled like shared ZA functions, but we don't require ZA state to be reloaded before agnostic ZA calls.

Note: This patch does not yet fully handle agnostic ZA functions that can catch exceptions. E.g.:

```
__arm_agnostic("sme_za_state") void try_catch_agnostic_za_callee() {
  try {
    agnostic_za_call();
  } catch(...) {
    noexcept_agnostic_za_call();
  }
}
```

In this case, we won't commit a ZA save before the `agnostic_za_call()`, which would be needed to restore ZA in the catch block. This will be handled in a later patch.
1 parent 4546522 commit 96d5567

File tree

4 files changed

+388
-51
lines changed

4 files changed

+388
-51
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8489,13 +8489,22 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
84898489
if (Subtarget->hasCustomCallingConv())
84908490
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
84918491

8492-
if (getTM().useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
8492+
if (getTM().useNewSMEABILowering()) {
84938493
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
84948494
SDValue Size;
84958495
if (Attrs.hasZAState()) {
84968496
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
84978497
DAG.getConstant(1, DL, MVT::i32));
84988498
Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8499+
} else if (Attrs.hasAgnosticZAInterface()) {
8500+
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
8501+
SDValue Callee = DAG.getExternalSymbol(
8502+
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
8503+
auto *RetTy = EVT(MVT::i64).getTypeForEVT(*DAG.getContext());
8504+
TargetLowering::CallLoweringInfo CLI(DAG);
8505+
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8506+
getLibcallCallingConv(LC), RetTy, Callee, {});
8507+
std::tie(Size, Chain) = LowerCallTo(CLI);
84998508
}
85008509
if (Size) {
85018510
SDValue Buffer = DAG.getNode(
@@ -8561,7 +8570,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
85618570
Register BufferPtr =
85628571
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
85638572
FuncInfo->setSMESaveBufferAddr(BufferPtr);
8564-
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
8573+
Chain = DAG.getCopyToReg(Buffer.getValue(1), DL, BufferPtr, Buffer);
85658574
}
85668575
}
85678576

@@ -9300,17 +9309,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
93009309

93019310
// Determine whether we need any streaming mode changes.
93029311
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
9312+
9313+
std::optional<unsigned> ZAMarkerNode;
93039314
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
9304-
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
9305-
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9306-
// TODO: Handle agnostic ZA functions.
9307-
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9308-
return std::nullopt;
9309-
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
9310-
return std::nullopt;
9311-
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
9312-
: AArch64ISD::INOUT_ZA_USE;
9313-
}();
9315+
if (UseNewSMEABILowering) {
9316+
if (CallAttrs.requiresLazySave() ||
9317+
CallAttrs.requiresPreservingAllZAState())
9318+
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
9319+
else if (CallAttrs.caller().hasZAState() ||
9320+
CallAttrs.caller().hasZT0State())
9321+
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
9322+
}
93149323

93159324
if (IsTailCall) {
93169325
// Check if it's really possible to do a tail call.
@@ -9385,7 +9394,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
93859394
};
93869395

93879396
bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
9388-
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9397+
bool RequiresSaveAllZA =
9398+
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
93899399
if (RequiresLazySave) {
93909400
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
93919401
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
261261
EarlyAllocSMESaveBuffer = Ptr;
262262
}
263263

264-
Register getEarlyAllocSMESaveBuffer() { return EarlyAllocSMESaveBuffer; }
264+
Register getEarlyAllocSMESaveBuffer() const {
265+
return EarlyAllocSMESaveBuffer;
266+
}
265267

266268
// Old SME ABI lowering state getters/setters:
267269
Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 147 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This pass implements the SME ABI requirements for ZA state. This includes
10-
// implementing the lazy ZA state save schemes around calls.
10+
// implementing the lazy (and agnostic) ZA state save schemes around calls.
1111
//
1212
//===----------------------------------------------------------------------===//
1313
//
@@ -215,9 +215,44 @@ struct MachineSMEABI : public MachineFunctionPass {
215215
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
216216
bool ClearTPIDR2);
217217

218+
// Emission routines for agnostic ZA functions.
219+
void emitSetupFullZASave(MachineBasicBlock &MBB,
220+
MachineBasicBlock::iterator MBBI,
221+
LiveRegs PhysLiveRegs);
222+
// Emit a "full" ZA save or restore. It is "full" in the sense that this
223+
// function will emit a call to __arm_sme_save or __arm_sme_restore, which
224+
// handles saving and restoring both ZA and ZT0.
225+
void emitFullZASaveRestore(MachineBasicBlock &MBB,
226+
MachineBasicBlock::iterator MBBI,
227+
LiveRegs PhysLiveRegs, bool IsSave);
228+
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
229+
MachineBasicBlock::iterator MBBI,
230+
LiveRegs PhysLiveRegs);
231+
218232
void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
219233
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
220234

235+
// Helpers for switching between lazy/full ZA save/restore routines.
236+
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
237+
LiveRegs PhysLiveRegs) {
238+
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
239+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
240+
return emitSetupLazySave(MBB, MBBI);
241+
}
242+
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
243+
LiveRegs PhysLiveRegs) {
244+
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
245+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
246+
return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
247+
}
248+
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
249+
MachineBasicBlock::iterator MBBI,
250+
LiveRegs PhysLiveRegs) {
251+
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
252+
return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
253+
return emitAllocateLazySaveBuffer(MBB, MBBI);
254+
}
255+
221256
/// Save live physical registers to virtual registers.
222257
PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
223258
MachineBasicBlock::iterator MBBI, DebugLoc DL);
@@ -228,6 +263,8 @@ struct MachineSMEABI : public MachineFunctionPass {
228263
/// Get or create a TPIDR2 block in this function.
229264
TPIDR2State getTPIDR2Block();
230265

266+
Register getAgnosticZABufferPtr();
267+
231268
private:
232269
/// Contains the needed ZA state (and live registers) at an instruction.
233270
struct InstInfo {
@@ -241,6 +278,7 @@ struct MachineSMEABI : public MachineFunctionPass {
241278
struct BlockInfo {
242279
ZAState FixedEntryState{ZAState::ANY};
243280
SmallVector<InstInfo> Insts;
281+
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
244282
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
245283
};
246284

@@ -250,18 +288,22 @@ struct MachineSMEABI : public MachineFunctionPass {
250288
SmallVector<ZAState> BundleStates;
251289
std::optional<TPIDR2State> TPIDR2Block;
252290
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
291+
Register AgnosticZABufferPtr = AArch64::NoRegister;
292+
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
253293
} State;
254294

255295
MachineFunction *MF = nullptr;
256296
EdgeBundles *Bundles = nullptr;
257297
const AArch64Subtarget *Subtarget = nullptr;
258298
const AArch64RegisterInfo *TRI = nullptr;
299+
const AArch64FunctionInfo *AFI = nullptr;
259300
const TargetInstrInfo *TII = nullptr;
260301
MachineRegisterInfo *MRI = nullptr;
261302
};
262303

263304
void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
264-
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
305+
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
306+
SMEFnAttrs.hasZAState()) &&
265307
"Expected function to have ZA/ZT0 state!");
266308

267309
State.Blocks.resize(MF->getNumBlockIDs());
@@ -295,6 +337,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
295337

296338
Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
297339
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
340+
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
298341
for (MachineInstr &MI : reverse(MBB)) {
299342
MachineBasicBlock::iterator MBBI(MI);
300343
LiveUnits.stepBackward(MI);
@@ -303,8 +346,11 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
303346
// buffer was allocated in SelectionDAG. It marks the end of the
304347
// allocation -- which is a safe point for this pass to insert any TPIDR2
305348
// block setup.
306-
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo)
349+
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
307350
State.AfterSMEProloguePt = MBBI;
351+
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
352+
}
353+
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
308354
auto [NeededState, InsertPt] = getZAStateBeforeInst(
309355
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
310356
assert((InsertPt == MBBI ||
@@ -313,6 +359,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
313359
// TODO: Do something to avoid state changes where NZCV is live.
314360
if (MBBI == FirstTerminatorInsertPt)
315361
Block.PhysLiveRegsAtExit = PhysLiveRegs;
362+
if (MBBI == FirstNonPhiInsertPt)
363+
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
316364
if (NeededState != ZAState::ANY)
317365
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
318366
}
@@ -536,8 +584,6 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
536584
void MachineSMEABI::emitAllocateLazySaveBuffer(
537585
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
538586
MachineFrameInfo &MFI = MF->getFrameInfo();
539-
auto *AFI = MF->getInfo<AArch64FunctionInfo>();
540-
541587
DebugLoc DL = getDebugLoc(MBB, MBBI);
542588
Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
543589
Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
@@ -601,8 +647,7 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
601647
.addImm(AArch64SysReg::TPIDR2_EL0);
602648
// If TPIDR2_EL0 is non-zero, commit the lazy save.
603649
// NOTE: Functions that only use ZT0 don't need to zero ZA.
604-
bool ZeroZA =
605-
MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs().hasZAState();
650+
bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
606651
auto CommitZASave =
607652
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
608653
.addReg(TPIDR2EL0)
@@ -617,6 +662,86 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
617662
.addImm(1);
618663
}
619664

665+
Register MachineSMEABI::getAgnosticZABufferPtr() {
666+
if (State.AgnosticZABufferPtr != AArch64::NoRegister)
667+
return State.AgnosticZABufferPtr;
668+
Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer();
669+
State.AgnosticZABufferPtr =
670+
BufferPtr != AArch64::NoRegister
671+
? BufferPtr
672+
: MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
673+
return State.AgnosticZABufferPtr;
674+
}
675+
676+
void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
677+
MachineBasicBlock::iterator MBBI,
678+
LiveRegs PhysLiveRegs, bool IsSave) {
679+
auto *TLI = Subtarget->getTargetLowering();
680+
DebugLoc DL = getDebugLoc(MBB, MBBI);
681+
Register BufferPtr = AArch64::X0;
682+
683+
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
684+
685+
// Copy the buffer pointer into X0.
686+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
687+
.addReg(getAgnosticZABufferPtr());
688+
689+
// Call __arm_sme_save/__arm_sme_restore.
690+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
691+
.addReg(BufferPtr, RegState::Implicit)
692+
.addExternalSymbol(TLI->getLibcallName(
693+
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
694+
.addRegMask(TRI->getCallPreservedMask(
695+
*MF,
696+
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
697+
698+
restorePhyRegSave(RegSave, MBB, MBBI, DL);
699+
}
700+
701+
void MachineSMEABI::emitAllocateFullZASaveBuffer(
702+
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
703+
LiveRegs PhysLiveRegs) {
704+
// Buffer already allocated in SelectionDAG.
705+
if (AFI->getEarlyAllocSMESaveBuffer())
706+
return;
707+
708+
DebugLoc DL = getDebugLoc(MBB, MBBI);
709+
Register BufferPtr = getAgnosticZABufferPtr();
710+
Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
711+
712+
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
713+
714+
// Calculate the SME state size.
715+
{
716+
auto *TLI = Subtarget->getTargetLowering();
717+
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
718+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
719+
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
720+
.addReg(AArch64::X0, RegState::ImplicitDefine)
721+
.addRegMask(TRI->getCallPreservedMask(
722+
*MF, CallingConv::
723+
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
724+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
725+
.addReg(AArch64::X0);
726+
}
727+
728+
// Allocate a buffer object of the size given by __arm_sme_state_size.
729+
{
730+
MachineFrameInfo &MFI = MF->getFrameInfo();
731+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
732+
.addReg(AArch64::SP)
733+
.addReg(BufferSize)
734+
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
735+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
736+
.addReg(AArch64::SP);
737+
738+
// We have just allocated a variable sized object, tell this to PEI.
739+
MFI.CreateVariableSizedObject(Align(16), nullptr);
740+
}
741+
742+
restorePhyRegSave(RegSave, MBB, MBBI, DL);
743+
}
744+
620745
void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
621746
MachineBasicBlock::iterator InsertPt,
622747
ZAState From, ZAState To,
@@ -634,10 +759,7 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
634759
// TODO: Avoid setting up the save buffer if there's no transition to
635760
// LOCAL_SAVED.
636761
if (From == ZAState::CALLER_DORMANT) {
637-
assert(MBB.getParent()
638-
->getInfo<AArch64FunctionInfo>()
639-
->getSMEFnAttrs()
640-
.hasPrivateZAInterface() &&
762+
assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
641763
"CALLER_DORMANT state requires private ZA interface");
642764
assert(&MBB == &MBB.getParent()->front() &&
643765
"CALLER_DORMANT state only valid in entry block");
@@ -652,12 +774,14 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
652774
}
653775

654776
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
655-
emitSetupLazySave(MBB, InsertPt);
777+
emitZASave(MBB, InsertPt, PhysLiveRegs);
656778
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
657-
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
779+
emitZARestore(MBB, InsertPt, PhysLiveRegs);
658780
else if (To == ZAState::OFF) {
659781
assert(From != ZAState::CALLER_DORMANT &&
660782
"CALLER_DORMANT to OFF should have already been handled");
783+
assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
784+
"Should not turn ZA off in agnostic ZA function");
661785
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
662786
} else {
663787
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
@@ -675,9 +799,10 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
675799
if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
676800
return false;
677801

678-
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
802+
AFI = MF.getInfo<AArch64FunctionInfo>();
679803
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
680-
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
804+
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
805+
!SMEFnAttrs.hasAgnosticZAInterface())
681806
return false;
682807

683808
assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -696,15 +821,18 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
696821
insertStateChanges();
697822

698823
// Allocate save buffer (if needed).
699-
if (State.TPIDR2Block) {
824+
if (State.AgnosticZABufferPtr != AArch64::NoRegister || State.TPIDR2Block) {
700825
if (State.AfterSMEProloguePt) {
701826
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
702827
// entry block (due to the probing loop).
703-
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
704-
*State.AfterSMEProloguePt);
828+
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
829+
*State.AfterSMEProloguePt,
830+
State.PhysLiveRegsAfterSMEPrologue);
705831
} else {
706832
MachineBasicBlock &EntryBlock = MF.front();
707-
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
833+
emitAllocateZASaveBuffer(
834+
EntryBlock, EntryBlock.getFirstNonPHI(),
835+
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
708836
}
709837
}
710838

0 commit comments

Comments
 (0)