36 changes: 23 additions & 13 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8489,13 +8489,22 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
if (Subtarget->hasCustomCallingConv())
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);

if (getTM().useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
if (getTM().useNewSMEABILowering()) {
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
SDValue Size;
if (Attrs.hasZAState()) {
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
DAG.getConstant(1, DL, MVT::i32));
Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
} else if (Attrs.hasAgnosticZAInterface()) {
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
SDValue Callee = DAG.getExternalSymbol(
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
auto *RetTy = EVT(MVT::i64).getTypeForEVT(*DAG.getContext());
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
getLibcallCallingConv(LC), RetTy, Callee, {});
std::tie(Size, Chain) = LowerCallTo(CLI);
}
if (Size) {
SDValue Buffer = DAG.getNode(
@@ -8561,7 +8570,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Register BufferPtr =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
FuncInfo->setSMESaveBufferAddr(BufferPtr);
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
Chain = DAG.getCopyToReg(Buffer.getValue(1), DL, BufferPtr, Buffer);
}
}

@@ -9300,17 +9309,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);

std::optional<unsigned> ZAMarkerNode;
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
// TODO: Handle agnostic ZA functions.
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
return std::nullopt;
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
return std::nullopt;
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
: AArch64ISD::INOUT_ZA_USE;
}();
if (UseNewSMEABILowering) {
if (CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState())
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
else if (CallAttrs.caller().hasZAState() ||
CallAttrs.caller().hasZT0State())
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
}

if (IsTailCall) {
// Check if it's really possible to do a tail call.
@@ -9385,7 +9394,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
};

bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
bool RequiresSaveAllZA =
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -261,7 +261,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
EarlyAllocSMESaveBuffer = Ptr;
}

Register getEarlyAllocSMESaveBuffer() { return EarlyAllocSMESaveBuffer; }
Register getEarlyAllocSMESaveBuffer() const {
return EarlyAllocSMESaveBuffer;
}

// Old SME ABI lowering state getters/setters:
Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
166 changes: 147 additions & 19 deletions llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This pass implements the SME ABI requirements for ZA state. This includes
// implementing the lazy ZA state save schemes around calls.
// implementing the lazy (and agnostic) ZA state save schemes around calls.
//
//===----------------------------------------------------------------------===//
//
@@ -215,9 +215,44 @@ struct MachineSMEABI : public MachineFunctionPass {
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2);

// Emission routines for agnostic ZA functions.
void emitSetupFullZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs);
// Emit a "full" ZA save or restore. It is "full" in the sense that this
// function will emit a call to __arm_sme_save or __arm_sme_restore, which
// handles saving and restoring both ZA and ZT0.
void emitFullZASaveRestore(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs, bool IsSave);
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs);

void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
ZAState From, ZAState To, LiveRegs PhysLiveRegs);

// Helpers for switching between lazy/full ZA save/restore routines.
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs) {
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
return emitSetupLazySave(MBB, MBBI);
}
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs) {
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
}
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs) {
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
return emitAllocateLazySaveBuffer(MBB, MBBI);
}

/// Save live physical registers to virtual registers.
PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, DebugLoc DL);
@@ -228,6 +263,8 @@ struct MachineSMEABI : public MachineFunctionPass {
/// Get or create a TPIDR2 block in this function.
TPIDR2State getTPIDR2Block();

Register getAgnosticZABufferPtr();

private:
/// Contains the needed ZA state (and live registers) at an instruction.
struct InstInfo {
@@ -241,6 +278,7 @@ struct MachineSMEABI : public MachineFunctionPass {
struct BlockInfo {
ZAState FixedEntryState{ZAState::ANY};
SmallVector<InstInfo> Insts;
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};

@@ -250,18 +288,22 @@ struct MachineSMEABI : public MachineFunctionPass {
SmallVector<ZAState> BundleStates;
Contributor:

We are starting to accumulate a lot of state, which makes the code harder to follow, as it allows any member function to modify it instead of having clear ins/outs.

I know we discussed that already here, but I feel it would be nice not to delay the refactoring too much. Even having a first step that collects all the info in a struct would help. We could then pass that info around by const ref to any function that needs it. If some info needs to be mutable, then it should not be in the struct, and should instead be a clear in/out parameter.

Doing something like this would clearly decouple the "collection" phase from the "let me correctly handle the state changes" phase.

Member (author):

Here's a WIP patch that implements the scheme I mentioned previously: #156674
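A rough sketch of the kind of collection/emission split suggested above (a minimal illustration only; the struct and function names below are hypothetical and are not taken from this patch or from #156674):

namespace {

// Result of the "collection" phase. Built once while scanning the function,
// then treated as read-only by everything that follows.
struct ZAStateAnalysis {
  SmallVector<BlockInfo> Blocks;     // Needed ZA states per basic block.
  SmallVector<ZAState> BundleStates; // Chosen ZA state per edge bundle.
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
};

// Collection phase: walks the function and fills in the analysis.
ZAStateAnalysis collectZAStateAnalysis(MachineFunction &MF);

// Emission phase: consumes the analysis by const reference. Anything that
// must be created lazily (e.g. the TPIDR2 block or the agnostic-ZA buffer
// pointer) becomes an explicit in/out parameter rather than shared state.
void insertStateChanges(MachineFunction &MF, const ZAStateAnalysis &Analysis,
                        std::optional<TPIDR2State> &TPIDR2Block,
                        Register &AgnosticZABufferPtr);

} // anonymous namespace

Under such a split, runOnMachineFunction would first call collectZAStateAnalysis and then hand the result to the emission helpers, making the inputs and outputs of each phase explicit.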

std::optional<TPIDR2State> TPIDR2Block;
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
Register AgnosticZABufferPtr = AArch64::NoRegister;
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
} State;

MachineFunction *MF = nullptr;
EdgeBundles *Bundles = nullptr;
const AArch64Subtarget *Subtarget = nullptr;
const AArch64RegisterInfo *TRI = nullptr;
const AArch64FunctionInfo *AFI = nullptr;
const TargetInstrInfo *TII = nullptr;
MachineRegisterInfo *MRI = nullptr;
};

void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
SMEFnAttrs.hasZAState()) &&
"Expected function to have ZA/ZT0 state!");

State.Blocks.resize(MF->getNumBlockIDs());
@@ -295,6 +337,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {

Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
for (MachineInstr &MI : reverse(MBB)) {
MachineBasicBlock::iterator MBBI(MI);
LiveUnits.stepBackward(MI);
@@ -303,8 +346,11 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
// buffer was allocated in SelectionDAG. It marks the end of the
// allocation -- which is a safe point for this pass to insert any TPIDR2
// block setup.
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo)
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
State.AfterSMEProloguePt = MBBI;
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
}
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
auto [NeededState, InsertPt] = getZAStateBeforeInst(
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
assert((InsertPt == MBBI ||
@@ -313,6 +359,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
// TODO: Do something to avoid state changes where NZCV is live.
if (MBBI == FirstTerminatorInsertPt)
Block.PhysLiveRegsAtExit = PhysLiveRegs;
if (MBBI == FirstNonPhiInsertPt)
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
if (NeededState != ZAState::ANY)
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
}
@@ -536,8 +584,6 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
void MachineSMEABI::emitAllocateLazySaveBuffer(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
MachineFrameInfo &MFI = MF->getFrameInfo();
auto *AFI = MF->getInfo<AArch64FunctionInfo>();

DebugLoc DL = getDebugLoc(MBB, MBBI);
Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
@@ -601,8 +647,7 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
.addImm(AArch64SysReg::TPIDR2_EL0);
// If TPIDR2_EL0 is non-zero, commit the lazy save.
// NOTE: Functions that only use ZT0 don't need to zero ZA.
bool ZeroZA =
MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs().hasZAState();
bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
auto CommitZASave =
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
.addReg(TPIDR2EL0)
@@ -617,6 +662,86 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
.addImm(1);
}

Register MachineSMEABI::getAgnosticZABufferPtr() {
if (State.AgnosticZABufferPtr != AArch64::NoRegister)
return State.AgnosticZABufferPtr;
Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer();
State.AgnosticZABufferPtr =
BufferPtr != AArch64::NoRegister
? BufferPtr
: MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
return State.AgnosticZABufferPtr;
}

void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs, bool IsSave) {
auto *TLI = Subtarget->getTargetLowering();
DebugLoc DL = getDebugLoc(MBB, MBBI);
Register BufferPtr = AArch64::X0;

PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

// Copy the buffer pointer into X0.
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
.addReg(getAgnosticZABufferPtr());

// Call __arm_sme_save/__arm_sme_restore.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
.addReg(BufferPtr, RegState::Implicit)
.addExternalSymbol(TLI->getLibcallName(
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
.addRegMask(TRI->getCallPreservedMask(
*MF,
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));

restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

void MachineSMEABI::emitAllocateFullZASaveBuffer(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs) {
// Buffer already allocated in SelectionDAG.
if (AFI->getEarlyAllocSMESaveBuffer())
return;

DebugLoc DL = getDebugLoc(MBB, MBBI);
Register BufferPtr = getAgnosticZABufferPtr();
Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);

PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

// Calculate the SME state size.
{
auto *TLI = Subtarget->getTargetLowering();
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
.addReg(AArch64::X0, RegState::ImplicitDefine)
.addRegMask(TRI->getCallPreservedMask(
*MF, CallingConv::
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
.addReg(AArch64::X0);
}

// Allocate a buffer object of the size given by __arm_sme_state_size.
{
MachineFrameInfo &MFI = MF->getFrameInfo();
BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
.addReg(AArch64::SP)
.addReg(BufferSize)
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
.addReg(AArch64::SP);

// We have just allocated a variable sized object, tell this to PEI.
MFI.CreateVariableSizedObject(Align(16), nullptr);
}

restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
MachineBasicBlock::iterator InsertPt,
ZAState From, ZAState To,
@@ -634,10 +759,7 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
// TODO: Avoid setting up the save buffer if there's no transition to
// LOCAL_SAVED.
if (From == ZAState::CALLER_DORMANT) {
assert(MBB.getParent()
->getInfo<AArch64FunctionInfo>()
->getSMEFnAttrs()
.hasPrivateZAInterface() &&
assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
"CALLER_DORMANT state requires private ZA interface");
assert(&MBB == &MBB.getParent()->front() &&
"CALLER_DORMANT state only valid in entry block");
@@ -652,12 +774,14 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
}

if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
emitSetupLazySave(MBB, InsertPt);
emitZASave(MBB, InsertPt, PhysLiveRegs);
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
emitZARestore(MBB, InsertPt, PhysLiveRegs);
else if (To == ZAState::OFF) {
assert(From != ZAState::CALLER_DORMANT &&
"CALLER_DORMANT to OFF should have already been handled");
assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
"Should not turn ZA off in agnostic ZA function");
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
} else {
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
@@ -675,9 +799,10 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
return false;

auto *AFI = MF.getInfo<AArch64FunctionInfo>();
AFI = MF.getInfo<AArch64FunctionInfo>();
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
!SMEFnAttrs.hasAgnosticZAInterface())
return false;

assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -696,15 +821,18 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
insertStateChanges();

// Allocate save buffer (if needed).
if (State.TPIDR2Block) {
if (State.AgnosticZABufferPtr != AArch64::NoRegister || State.TPIDR2Block) {
if (State.AfterSMEProloguePt) {
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
// entry block (due to the probing loop).
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
*State.AfterSMEProloguePt);
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
*State.AfterSMEProloguePt,
State.PhysLiveRegsAfterSMEPrologue);
} else {
MachineBasicBlock &EntryBlock = MF.front();
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
emitAllocateZASaveBuffer(
EntryBlock, EntryBlock.getFirstNonPHI(),
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
}
}
